示例#1
0
def test_reduce_psi_psi():
    """PSI(i, PSI(j, A)) is rewritten to a single PSI(j ++ i, A).

    The outer index ``_a1`` holds a reference to the index symbol ``_i0``;
    the combined index ``_a4`` expected afterwards is the inner index
    ``(1, 2, 3)`` followed by that same reference.
    """
    def index_ref():
        # A fresh node on every call so input and expectation share nothing.
        return ast.Node((ast.NodeSymbol.INDEX, ), (), ('_i0', ), ())

    def array_ref(node_shape, name):
        return ast.Node((ast.NodeSymbol.ARRAY, ), node_shape, (name, ), ())

    symbol_table = {
        '_i0': ast.SymbolNode(ast.NodeSymbol.INDEX, (), None, (0, 3, 1)),
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1, ), None, (index_ref(), )),
        '_a2': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, ), None, (1, 2, 3)),
        '_a3': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1, 2, 3, 4), None, None),
    }
    inner_psi = ast.Node((ast.NodeSymbol.PSI, ), (4, ), (),
                         (array_ref((3, ), '_a2'), array_ref((1, 2, 3, 4), '_a3')))
    tree = ast.Node((ast.NodeSymbol.PSI, ), (0, ), (),
                    (array_ref((1, ), '_a1'), inner_psi))

    expected_symbol_table = dict(symbol_table)
    expected_symbol_table['_a4'] = ast.SymbolNode(
        ast.NodeSymbol.ARRAY, (4, ), None, (1, 2, 3, index_ref()))
    expected_tree = ast.Node((ast.NodeSymbol.PSI, ), (0, ), (),
                             (array_ref((4, ), '_a4'), array_ref((1, 2, 3, 4), '_a3')))

    testing.assert_transformation(tree, symbol_table, expected_tree,
                                  expected_symbol_table, dnf._reduce_psi_psi)
示例#2
0
def test_dimension_operation():
    """``shape.dimension`` of a TRANSPOSE node shaped (4, 3) is 2."""
    table = {'_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, 4), None, None)}
    operand = ast.Node((ast.NodeSymbol.ARRAY,), (3, 4), ('_a1',), ())
    transpose = ast.Node((ast.NodeSymbol.TRANSPOSE,), (4, 3), (), (operand,))

    assert shape.dimension(ast.create_context(ast=transpose, symbol_table=table)) == 2
示例#3
0
def test_array_transpose_default():
    """``LazyArray.transpose()`` with no axes builds a plain TRANSPOSE node."""
    lazy = LazyArray(name='A', shape=(2, 3)).transpose()

    operand = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())
    expected = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.TRANSPOSE,), None, (), (operand,)),
        symbol_table={'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None)},
    )

    testing.assert_context_equal(expected, lazy.context)
示例#4
0
def test_ast_select_node_shape(selection, result_shape):
    """``ast.select_node_shape`` yields ``result_shape`` at ``selection``.

    ``selection`` / ``result_shape`` presumably come from a pytest
    parametrization outside this view.
    """
    nested_array = ast.Node((ast.NodeSymbol.ARRAY, ), (10, 11, 12), (2, ), ())
    left = ast.Node((ast.NodeSymbol.ARRAY, ), (4, 5, 6), (0, ), ())
    right = ast.Node((ast.NodeSymbol.TRANSPOSE, ), (7, 8, 9), (1, ), (nested_array, ))
    root = ast.Node((ast.NodeSymbol.PLUS, ), (1, 2, 3), (), (left, right))

    assert ast.select_node_shape(ast.create_context(ast=root), selection) == result_shape
示例#5
0
def test_array_single_array_symbolic():
    """A symbolic dimension ``'n'`` gets its own 0-d symbol table entry.

    The shape of ``A`` refers to ``n`` through an ARRAY node with shape ().
    """
    expression = LazyArray(name='A', shape=('n', 3))

    n_dim = ast.Node((ast.NodeSymbol.ARRAY,), (), ('n',), ())
    expected_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (n_dim, 3), None, None),
        'n': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, None),
    }
    expected_tree = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())

    testing.assert_context_equal(
        ast.create_context(ast=expected_tree, symbol_table=expected_table),
        expression.context)
示例#6
0
def test_array_reduce(symbol):
    """``reduce(symbol)`` wraps the array in a REDUCE node.

    The node symbol carries the operation that ``LazyArray.OPPERATION_MAP``
    associates with ``symbol``.
    """
    expression = LazyArray(name='A', shape=(2, 3)).reduce(symbol)

    reduce_symbol = (ast.NodeSymbol.REDUCE, LazyArray.OPPERATION_MAP[symbol])
    operand = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())
    expected_context = ast.create_context(
        ast=ast.Node(reduce_symbol, None, (), (operand,)),
        symbol_table={'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None)},
    )

    testing.assert_context_equal(expected_context, expression.context)
示例#7
0
def test_ast_select_node(selection, result_node):
    """``ast.select_node`` at ``selection`` yields a context around ``result_node``."""
    leaf = ast.Node((ast.NodeSymbol.ARRAY, ), None, (2, ), ())
    transpose = ast.Node((ast.NodeSymbol.TRANSPOSE, ), None, (1, ), (leaf, ))
    first = ast.Node((ast.NodeSymbol.ARRAY, ), None, (0, ), ())
    root = ast.Node((ast.NodeSymbol.PLUS, ), None, (), (first, transpose))

    context = ast.create_context(ast=root)
    expected = ast.create_context(ast=result_node)
    testing.assert_context_equal(ast.select_node(context, selection), expected)
示例#8
0
def test_not_matches_rule_nested():
    """TRANSPOSE(ARRAY) must not match the nested rule TRANSPOSE(TRANSPOSE)."""
    operand = ast.Node((ast.NodeSymbol.ARRAY, ), (2, 3), ('A', ), ())
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.TRANSPOSE, ), None, (), (operand, )),
        symbol_table={
            'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None,
                                (1, 2, 3, 4, 5, 6)),
        })

    nested_transpose_rule = ((ast.NodeSymbol.TRANSPOSE, ),
                             (((ast.NodeSymbol.TRANSPOSE, ), ), ))
    assert not dnf.matches_rule(nested_transpose_rule, context)
示例#9
0
def test_array_transpose_with_vector():
    """``transpose([1, 0])`` builds TRANSPOSEV(_a1, A), storing the axes as ``_a1``."""
    expression = LazyArray(name='A', shape=(2, 3)).transpose([1, 0])

    axes = ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ())
    source = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())
    expected_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None),
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2,), None, (1, 0)),
    }
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.TRANSPOSEV,), None, (), (axes, source)),
        symbol_table=expected_table)

    testing.assert_context_equal(context, expression.context)
示例#10
0
def test_array_index_tuple():
    """``A[1, 0]`` builds PSI(_a1, A), storing the index tuple as ``_a1``."""
    expression = LazyArray(name='A', shape=(2, 3))[1, 0]

    index = ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ())
    source = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())
    expected_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None),
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2,), None, (1, 0)),
    }
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.PSI,), None, (), (index, source)),
        symbol_table=expected_table)

    testing.assert_context_equal(context, expression.context)
示例#11
0
def test_array_addition():
    """Adding two LazyArrays builds PLUS(A, B) with both in the symbol table."""
    left = LazyArray(name='A', shape=(2, 3))
    right = LazyArray(name='B', shape=(2, 3))
    expression = left + right

    names = ('A', 'B')
    operands = tuple(ast.Node((ast.NodeSymbol.ARRAY,), None, (name,), ()) for name in names)
    symbol_table = {
        name: ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None) for name in names
    }
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.PLUS,), None, (), operands),
        symbol_table=symbol_table)

    testing.assert_context_equal(context, expression.context)
示例#12
0
def test_array_index_stride_reverse():
    """Index ``A`` with a reversed-stride slice ``[1:2:-1]``.

    The expected tree is built the same way as every other test in this
    file. Node symbols are tuples (e.g. ``(ast.NodeSymbol.PSI,)``), and each
    ``ast.Node`` gets all four positional fields ``(symbol, shape, attrib,
    child)``, with ``attrib`` as a tuple of names.

    NOTE(review): the expected index ``_a2`` (one symbolic element that
    refers to ``n``) predates this file's current conventions; confirm it
    still matches what ``LazyArray.__getitem__`` emits for a negative stride.
    """
    expression = LazyArray(name='A', shape=(2, 3))[1:2:-1]
    tree = ast.Node((ast.NodeSymbol.PSI,), None, (), (
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a2',), ()),
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())))
    symbol_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None),
        'n': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, None),
        # Symbolic reference to ``n``, in the same form as
        # test_array_single_array_symbolic uses for symbolic dimensions.
        '_a2': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1,), None, (
            ast.Node((ast.NodeSymbol.ARRAY,), (), ('n',), ()),)),
    }
    context = ast.create_context(ast=tree, symbol_table=symbol_table)

    testing.assert_context_equal(context, expression.context)
示例#13
0
def test_array_inner_product(left_symbol, right_symbol):
    """``A.inner(l, r, B)`` builds a DOT node carrying both mapped operations."""
    lhs = LazyArray(name='A', shape=(2, 3))
    rhs = LazyArray(name='B', shape=(3, 4))
    expression = lhs.inner(left_symbol, right_symbol, rhs)

    op_map = LazyArray.OPPERATION_MAP
    dot_symbol = (ast.NodeSymbol.DOT, op_map[left_symbol], op_map[right_symbol])
    children = (
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ()),
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('B',), ()),
    )
    expected_context = ast.create_context(
        ast=ast.Node(dot_symbol, None, (), children),
        symbol_table={
            'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None),
            'B': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, 4), None, None),
        })

    testing.assert_context_equal(expected_context, expression.context)
示例#14
0
def test_reduce_psi_transposev():
    """PSI(i, TRANSPOSEV(p, A)) is rewritten to PSI(i', A).

    The transposed index ``_a0`` = (3, 2, 1, 1) under permutation
    ``_a1`` = (4, 2, 3, 1) is expected to become index ``_a3`` =
    (1, 2, 1, 3) into the untransposed array ``_a2``.
    """
    def array_ref(node_shape, name):
        return ast.Node((ast.NodeSymbol.ARRAY, ), node_shape, (name, ), ())

    def vector(values):
        return ast.SymbolNode(ast.NodeSymbol.ARRAY, (4, ), None, values)

    symbol_table = {
        '_a0': vector((3, 2, 1, 1)),
        '_a1': vector((4, 2, 3, 1)),
        '_a2': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3, 5, 7), None, None),
    }
    transposev = ast.Node((ast.NodeSymbol.TRANSPOSEV, ), (7, 3, 5, 2), (),
                          (array_ref((4, ), '_a1'), array_ref((2, 3, 5, 7), '_a2')))
    tree = ast.Node((ast.NodeSymbol.PSI, ), (), (),
                    (array_ref((4, ), '_a0'), transposev))

    expected_symbol_table = dict(symbol_table, _a3=vector((1, 2, 1, 3)))
    expected_tree = ast.Node((ast.NodeSymbol.PSI, ), (), (),
                             (array_ref((4, ), '_a3'), array_ref((2, 3, 5, 7), '_a2')))

    testing.assert_transformation(tree, symbol_table, expected_tree,
                                  expected_symbol_table,
                                  dnf._reduce_psi_transposev)
示例#15
0
def test_shape_unit_reduce_plus_minus_multiply_divide_no_symbol(operation):
    """Shape inference for REDUCE maps an operand of (1, 2, 3) to (2, 3).

    Also checks that ``shape.calculate_shapes`` leaves its input context unchanged.
    """
    symbol_table = {'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1, 2, 3), None, None)}

    def reduce_over(operand_shape, result_shape):
        operand = ast.Node((ast.NodeSymbol.ARRAY,), operand_shape, ('A',), ())
        return ast.Node((ast.NodeSymbol.REDUCE, operation), result_shape, (), (operand,))

    context = ast.create_context(ast=reduce_over(None, None), symbol_table=symbol_table)
    expected_context = ast.create_context(ast=reduce_over((1, 2, 3), (2, 3)),
                                          symbol_table=symbol_table)
    snapshot = copy.deepcopy(context)

    new_context = shape.calculate_shapes(context)
    testing.assert_context_equal(context, snapshot)
    testing.assert_context_equal(expected_context, new_context)
示例#16
0
def test_matches_rule_simple():
    """A bare ARRAY node matches the single-symbol rule ``(ARRAY,)``."""
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.ARRAY, ), (2, 3), ('A', ), ()),
        symbol_table={
            'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None,
                                (1, 2, 3, 4, 5, 6)),
        })

    array_rule = ((ast.NodeSymbol.ARRAY, ), )
    assert dnf.matches_rule(array_rule, context)
示例#17
0
def test_add_indexing_node():
    """``dnf.add_indexing_node`` wraps an array in PSI over fresh index symbols.

    One INDEX symbol per axis is expected (``_i1`` with (0, 2, 1), ``_i2``
    with (0, 3, 1), presumably start/stop/step), gathered into the index
    vector ``_a3``.
    """
    data = (1, 2, 3, 4, 5, 6)
    tree = ast.Node((ast.NodeSymbol.ARRAY, ), (2, 3), ('A', ), ())
    symbol_table = {'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, data)}

    def scalar_ref(name):
        return ast.Node((ast.NodeSymbol.ARRAY, ), (), (name, ), ())

    expected_tree = ast.Node((ast.NodeSymbol.PSI, ), (2, 3), (), (
        ast.Node((ast.NodeSymbol.ARRAY, ), (2, ), ('_a3', ), ()),
        ast.Node((ast.NodeSymbol.ARRAY, ), (2, 3), ('A', ), ()),
    ))
    expected_symbol_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, data),
        '_i1': ast.SymbolNode(ast.NodeSymbol.INDEX, (), None, (0, 2, 1)),
        '_i2': ast.SymbolNode(ast.NodeSymbol.INDEX, (), None, (0, 3, 1)),
        '_a3': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, ), None,
                              (scalar_ref('_i1'), scalar_ref('_i2'))),
    }

    testing.assert_transformation(tree, symbol_table, expected_tree,
                                  expected_symbol_table, dnf.add_indexing_node)
示例#18
0
def test_shape_unit_plus_symbolic(left_shape, right_shape):
    """PLUS of arrays with symbolic (``n``) dimensions infers shape (3, 4, 5).

    In this test the PLUS expression is expected at position ``(1,)`` of the
    tree returned by ``shape.calculate_shapes``. Its sibling at ``(0,)`` is
    presumably the generated shape condition and is not checked here. The
    test also checks that the input context is left unchanged.
    """
    symbol_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, left_shape, None, None),
        'B': ast.SymbolNode(ast.NodeSymbol.ARRAY, right_shape, None, None),
        'n': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, None),
    }
    tree = ast.Node((ast.NodeSymbol.PLUS,), None, (), (
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ()),
        ast.Node((ast.NodeSymbol.ARRAY,), None, ('B',), ()),))
    expected_tree = ast.Node((ast.NodeSymbol.PLUS,), (3, 4, 5), (), (
        ast.Node((ast.NodeSymbol.ARRAY,), left_shape, ('A',), ()),
        ast.Node((ast.NodeSymbol.ARRAY,), right_shape, ('B',), ()),))

    context = ast.create_context(ast=tree, symbol_table=symbol_table)
    expected_context = ast.create_context(ast=expected_tree, symbol_table={
        **symbol_table,
        # NOTE(review): presumably the concrete value inferred for ``n``.
        '_a3': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, (5,)),
    })
    context_copy = copy.deepcopy(context)

    new_context = shape.calculate_shapes(context)
    plus_context = ast.select_node(new_context, (1,))
    testing.assert_context_equal(context, context_copy)
    testing.assert_context_equal(expected_context, plus_context)
示例#19
0
def test_array_single_array_binary_operation_cast(function, side, operation):
    """A scalar operand is stored as the 0-d array ``_a1`` holding ``(1,)``.

    ``side`` names where ``_a1`` appears: ``'right'`` expects
    ``operation(A, _a1)``; anything else expects ``operation(_a1, A)``.
    """
    expression = function()

    array_a = ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ())
    scalar = ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ())
    children = (array_a, scalar) if side == 'right' else (scalar, array_a)
    tree = ast.Node((operation,), None, (), children)

    context = ast.create_context(ast=tree, symbol_table={
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, None),
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, (1,)),
    })

    testing.assert_context_equal(context, expression.context)
示例#20
0
def test_graphviz_visualization():
    """Smoke test: ``visualize_ast`` on PSI(_a1, TRANSPOSE(A + B)) must not raise."""
    def array_ref(name):
        return ast.Node((ast.NodeSymbol.ARRAY, ), None, (name, ), ())

    addition = ast.Node((ast.NodeSymbol.PLUS, ), None, (),
                        (array_ref('A'), array_ref('B')))
    transposed = ast.Node((ast.NodeSymbol.TRANSPOSE, ), None, (), (addition, ))
    tree = ast.Node((ast.NodeSymbol.PSI, ), None, (),
                    (array_ref('_a1'), transposed))
    symbol_table = {
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1, ), None, (0, )),
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, 4), None, None),
        'B': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, 4), None, None),
    }

    visualize_ast(ast.create_context(ast=tree, symbol_table=symbol_table))
示例#21
0
def test_shape_scalar_plus_minus_multiply_divide_no_symbol(operation):
    """A binary op between a (3, 4, 5) array and a 0-d array keeps (3, 4, 5).

    Also checks that ``shape.calculate_shapes`` leaves its input context unchanged.
    """
    symbol_table = {
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (3, 4, 5), None, None),
        'B': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, (0,)),
    }

    def binary(result_shape, a_shape, b_shape):
        return ast.Node((operation,), result_shape, (), (
            ast.Node((ast.NodeSymbol.ARRAY,), a_shape, ('A',), ()),
            ast.Node((ast.NodeSymbol.ARRAY,), b_shape, ('B',), ()),
        ))

    context = ast.create_context(ast=binary(None, None, None), symbol_table=symbol_table)
    expected_context = ast.create_context(ast=binary((3, 4, 5), (3, 4, 5), ()),
                                          symbol_table=symbol_table)
    untouched = copy.deepcopy(context)

    new_context = shape.calculate_shapes(context)
    testing.assert_context_equal(context, untouched)
    testing.assert_context_equal(expected_context, new_context)
示例#22
0
def test_metric_flops():
    """DNF reduction lowers the FLOP count of PSI(_a1, TRANSPOSE(A + B)).

    With shapes computed, the metric is expected to be 1000. After
    ``dnf.reduce_to_dnf`` it is expected to drop to 10.
    """
    def array_ref(name):
        return ast.Node((ast.NodeSymbol.ARRAY,), None, (name,), ())

    addition = ast.Node((ast.NodeSymbol.PLUS,), None, (),
                        (array_ref('A'), array_ref('B')))
    tree = ast.Node((ast.NodeSymbol.PSI,), None, (), (
        array_ref('_a1'),
        ast.Node((ast.NodeSymbol.TRANSPOSE,), None, (), (addition,)),
    ))
    symbol_table = {
        '_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (1,), None, (0,)),
        'A': ast.SymbolNode(ast.NodeSymbol.ARRAY, (10, 100), None, None),
        'B': ast.SymbolNode(ast.NodeSymbol.ARRAY, (10, 100), None, None),
    }

    shaped = shape.calculate_shapes(
        ast.create_context(ast=tree, symbol_table=symbol_table))
    assert analysis.metric_flops(shaped) == 1000

    reduced = dnf.reduce_to_dnf(shaped)
    assert analysis.metric_flops(reduced) == 10
示例#23
0
def test_preorder_replacement():
    """``ast.node_traversal(..., traversal='preorder')`` visits parents first.

    Every node without a shape is stamped with ``(k,)`` from a shared
    counter, so the expected shapes number the eight nodes 0..7 in
    preorder. The test also checks that the input context is unchanged.
    """
    counter = itertools.count()

    def replacement_function(context):
        if context.ast.shape is not None:
            return None
        return ast.replace_node_shape(context, (next(counter), ), ())

    def build_tree(s):
        # s[k] is the shape assigned to the k-th node in preorder.
        return ast.Node((ast.NodeSymbol.CAT, ), s[0], (), (
            ast.Node((ast.NodeSymbol.TRANSPOSE, ), s[1], (), (
                ast.Node((ast.NodeSymbol.ARRAY, ), s[2], (), ()), )),
            ast.Node((ast.NodeSymbol.PLUS, ), s[3], (), (
                ast.Node((ast.NodeSymbol.SHAPE, ), s[4], (), (
                    ast.Node((ast.NodeSymbol.ARRAY, ), s[5], (), ()), )),
                ast.Node((ast.NodeSymbol.RAV, ), s[6], (), (
                    ast.Node((ast.NodeSymbol.ARRAY, ), s[7], (), ()), )),
            )),
        ))

    context = ast.create_context(ast=build_tree([None] * 8))
    context_copy = copy.deepcopy(context)
    expected_context = ast.create_context(ast=build_tree([(k, ) for k in range(8)]))

    new_context = ast.node_traversal(context,
                                     replacement_function,
                                     traversal='preorder')

    testing.assert_context_equal(context, context_copy)
    testing.assert_context_equal(new_context, expected_context)
示例#24
0
import itertools
import copy

import pytest

from moa import ast
from moa import visualize
from moa import testing


# is node type
@pytest.mark.parametrize('node, result', [
    # ARRAY leaf: an array, not an operation.
    (ast.Node((ast.NodeSymbol.ARRAY, ), None, (), ()),
     (True, False, False, False)),
    # TRANSPOSE with one child: a unary operation.
    (ast.Node((ast.NodeSymbol.TRANSPOSE, ), None, (), (None, )),
     (False, True, True, False)),
    # CAT with two children: a binary operation.
    (ast.Node((ast.NodeSymbol.CAT, ), None, (), (None, None)),
     (False, True, False, True)),
    # DOT with an extra TIMES symbol and two children: still binary.
    (ast.Node((ast.NodeSymbol.DOT, ast.NodeSymbol.TIMES), None, (), (None, None)),
     (False, True, False, True)),
])
def test_ast_nodes(node, result):
    """Classify ``node`` as (array, operation, unary operation, binary operation)."""
    context = ast.create_context(ast=node)
    predicates = (ast.is_array, ast.is_operation,
                  ast.is_unary_operation, ast.is_binary_operation)
    assert result == tuple(check(context) for check in predicates)


# node selection and replacement
@pytest.mark.parametrize('node, result', [
    (ast.Node((ast.NodeSymbol.ARRAY, ), None, (), ()), 0),
示例#25
0
def test_join_symbol_tables_simple():
    """``ast.join_symbol_tables`` merges the symbol tables of two contexts.

    Expected behaviour:

    - the names ``A``, ``B`` and ``m`` survive unchanged;
    - the right side's generated names are renumbered (``_a10`` -> ``_a0``,
      ``_a3`` -> ``_a1``), and references to them are rewritten;
    - ``n`` and the left side's unused ``_a1`` do not appear in the merged
      table;
    - each returned context keeps only the symbols its own tree needs.
    """
    def ref(name, node_shape=None):
        # A fresh ARRAY node referring to the symbol ``name``.
        return ast.Node((ast.NodeSymbol.ARRAY, ), node_shape, (name, ), ())

    def array(array_shape, value=None):
        return ast.SymbolNode(ast.NodeSymbol.ARRAY, array_shape, None, value)

    def plus_a_b():
        return ast.Node((ast.NodeSymbol.PLUS, ), None, (), (ref('A'), ref('B')))

    def minus_a(other):
        return ast.Node((ast.NodeSymbol.MINUS, ), None, (), (ref('A'), ref(other)))

    left_context = ast.create_context(
        ast=plus_a_b(),
        symbol_table={
            'A': array((3, 4)),
            '_a1': array((1, ref('m', ()))),
            'B': array((2, 4)),
        })
    right_context = ast.create_context(
        ast=minus_a('_a3'),
        symbol_table={
            'A': array((3, 4)),
            '_a3': array((ref('_a10', ()), ref('m', ()))),
            '_a10': array((), (1, )),
            'm': array(()),
            'B': array((2, 4)),
            'n': array(()),
        })

    new_symbol_table, new_left_context, new_right_context = ast.join_symbol_tables(
        left_context, right_context)

    assert new_symbol_table == {
        'A': array((3, 4)),
        'B': array((2, 4)),
        'm': array(()),
        '_a0': array((), (1, )),
        '_a1': array((ref('_a0', ()), ref('m', ()))),
    }

    expected_left_context = ast.create_context(
        ast=plus_a_b(),
        symbol_table={'A': array((3, 4)), 'B': array((2, 4))})
    expected_right_context = ast.create_context(
        ast=minus_a('_a1'),
        symbol_table={
            'A': array((3, 4)),
            'm': array(()),
            '_a0': array((), (1, )),
            '_a1': array((ref('_a0', ()), ref('m', ()))),
        })

    testing.assert_context_equal(new_left_context, expected_left_context)
    testing.assert_context_equal(new_right_context, expected_right_context)
示例#26
0
def test_dimension_array():
    """``shape.dimension`` of a 0-d array is 0."""
    scalar = ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ())
    context = ast.create_context(
        ast=scalar,
        symbol_table={'_a1': ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, (3,))},
    )

    assert shape.dimension(context) == 0
示例#27
0
def test_is_not_vector_2d():
    """``shape.is_vector`` is falsy for the 2-D array of shape (5, 1)."""
    column = ast.SymbolNode(ast.NodeSymbol.ARRAY, (5, 1), None, (1, 2, 3, 4, 5))
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.ARRAY,), None, ('A',), ()),
        symbol_table={'A': column},
    )

    assert not shape.is_vector(context)
示例#28
0
def test_is_not_vector_1d():
    """``shape.is_vector`` is falsy for a 0-d (scalar) array.

    NOTE(review): despite the ``_1d`` suffix, this case uses a scalar of shape ().
    """
    scalar = ast.SymbolNode(ast.NodeSymbol.ARRAY, (), None, (3,))
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.ARRAY,), None, ('asdf',), ()),
        symbol_table={'asdf': scalar},
    )

    assert not shape.is_vector(context)
示例#29
0
def test_is_vector():
    """``shape.is_vector`` is truthy for a 1-D array of shape (5,)."""
    vector = ast.SymbolNode(ast.NodeSymbol.ARRAY, (5,), None, (1, 2, 3, 4, 5))
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ()),
        symbol_table={'_a1': vector},
    )

    assert shape.is_vector(context)
示例#30
0
def test_is_not_scalar_2d():
    """``shape.is_scalar`` is falsy for the 2-D array of shape (2, 3)."""
    matrix = ast.SymbolNode(ast.NodeSymbol.ARRAY, (2, 3), None, (1, 2, 3, 4, 5, 6))
    context = ast.create_context(
        ast=ast.Node((ast.NodeSymbol.ARRAY,), None, ('_a1',), ()),
        symbol_table={'_a1': matrix},
    )

    assert not shape.is_scalar(context)