def test_vector_assign() -> None:
    """Assigning field ``b`` to field ``a`` emits a sliced numpy assignment."""
    generator = npir_gen.NpirGen()
    assign_node = VectorAssignFactory(left__name="a", right__name="b")
    generated = generator.visit(assign_node)
    expected = "a_[i:I, j:J, k_:(k_ + 1)] = b_[i:I, j:J, k_:(k_ + 1)]"
    assert generated == expected
def test_vector_unary_not() -> None:
    """Logical NOT on a field slice is rendered via ``np.bitwise_not``."""
    generator = npir_gen.NpirGen()
    negation = npir.VectorUnaryOp(
        expr=FieldSliceFactory(name="a"),
        op=common.UnaryOperator.NOT,
    )
    generated = generator.visit(negation)
    assert generated == "(np.bitwise_not(a_[i:I, j:J, k_:(k_ + 1)]))"
def test_broadcast_literal(defined_dtype: common.DataType) -> None:
    """Broadcasting a literal wraps it in the numpy scalar type of its dtype.

    The dot after ``np`` is escaped so the pattern matches a literal ``np.``
    prefix rather than "np followed by any character".
    """
    result = npir_gen.NpirGen().visit(
        npir.BroadCast(expr=npir.Literal(dtype=defined_dtype, value="42")))
    # Echoed for debugging; pytest reports captured stdout on failure.
    print(result)
    match = re.match(r"np\.(\w*?)\(42\)", result)
    assert match
    # The wrapper name must be the lower-cased dtype enum name (e.g. "int64").
    assert match.group(1) == defined_dtype.name.lower()
def test_vector_arithmetic() -> None:
    """Adding two field slices yields a parenthesised ``+`` expression."""
    generator = npir_gen.NpirGen()
    addition = npir.VectorArithmetic(
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
        op=common.ArithmeticOperator.ADD,
    )
    rendered = generator.visit(addition)
    assert rendered == "(a_[i:I, j:J, k_:(k_ + 1)] + b_[i:I, j:J, k_:(k_ + 1)])"
def test_mask_block_broadcast() -> None:
    """A broadcast boolean mask is assigned via ``np.full`` over the (I-i, J-j, K-k) shape."""
    generator = npir_gen.NpirGen()
    always_true = npir.BroadCast(
        expr=npir.Literal(dtype=common.DataType.BOOL, value=common.BuiltInLiteral.TRUE)
    )
    block = npir.MaskBlock(body=[], mask=always_true, mask_name="mask1")
    expected = "mask1_ = np.full((I - i, J - j, K - k), np.bool(True))\n"
    assert generator.visit(block) == expected
def test_horizontal_block() -> None:
    """An empty horizontal block starts by binding i/I and j/J with zero offsets.

    An optional leading comment line is tolerated.
    """
    code = npir_gen.NpirGen().visit(npir.HorizontalBlock(body=[]))
    print(code)
    header = r"(#.*?\n)?i, I = _di_ - 0, _dI_ \+ 0\nj, J = _dj_ - 0, _dJ_ \+ 0\n"
    assert re.match(header, code, flags=re.MULTILINE)
def test_temp_with_extent_definition() -> None:
    """A temporary with a non-zero extent gets a padded shape and shifted origin.

    Extent ((0, 1), (-2, 3)) pads I by 1 and J by 5, and offsets J by 2.
    """
    generator = npir_gen.NpirGen()
    temp_assign = VectorAssignFactory(temp_init=True, temp_name="a")
    extents = {"a": Extent((0, 1), (-2, 3))}
    code = generator.visit(temp_assign, field_extents=extents)
    expected = (
        "a_ = ShimmedView(np.zeros((_dI_ + 1, _dJ_ + 5, _dK_), dtype=np.int64), [0, 2, 0])"
    )
    assert code == expected
def test_native_function() -> None:
    """The ``MIN`` native function maps onto element-wise ``np.minimum``."""
    generator = npir_gen.NpirGen()
    operands = [FieldSliceFactory(name=field) for field in ("a", "b")]
    call = NativeFuncCallFactory(func=common.NativeFunction.MIN, args=operands)
    expected = "np.minimum(a_[i:I, j:J, k_:(k_ + 1)], b_[i:I, j:J, k_:(k_ + 1)])"
    assert generator.visit(call) == expected
def test_vertical_pass_par() -> None:
    """With the factory's default bounds, the pass binds ``k, K`` to ``_dk_, _dK_``."""
    code = npir_gen.NpirGen().visit(VerticalPassFactory(body=[], temp_defs=[]))
    print(code)
    header = r"(#.*?\n)?k, K = _dk_, _dK_\n"
    assert re.match(header, code, flags=re.MULTILINE)
def test_cast(defined_dtype: common.DataType,
              other_dtype: common.DataType) -> None:
    """Casting a literal emits ``np.array(np.<src>(42), dtype=np.<dst>)``.

    Every ``np.`` dot is escaped so it matches a literal ``.`` rather than
    any character.
    """
    result = npir_gen.NpirGen().visit(
        npir.Cast(dtype=other_dtype,
                  expr=npir.Literal(dtype=defined_dtype, value="42")))
    # Echoed for debugging; pytest reports captured stdout on failure.
    print(result)
    match = re.match(r"^np\.(\w*?)\(np\.(\w*)\(42\), dtype=np\.(\w*)\)", result)
    assert match
    wrapper, source_type, target_type = match.groups()
    assert wrapper == "array"
    assert source_type == defined_dtype.name.lower()
    assert target_type == other_dtype.name.lower()
def test_mask_block_other() -> None:
    """A logical-AND mask expression is assigned directly to ``mask1_``."""
    generator = npir_gen.NpirGen()
    conjunction = npir.VectorLogic(
        op=common.LogicalOperator.AND,
        left=FieldSliceFactory(name="a"),
        right=FieldSliceFactory(name="b"),
    )
    block = npir.MaskBlock(body=[], mask=conjunction, mask_name="mask1")
    code = generator.visit(block)
    assert code.startswith("mask1_ = np.bitwise_and(a_[i:I")
def test_verticall_pass_start_start_forward() -> None:
    """A forward pass ending at start+5 binds ``K = _dk_ + 5`` and loops ``k_`` upward.

    (The "verticall" spelling is kept so the pytest test id stays stable.)
    """
    generator = npir_gen.NpirGen()
    forward_pass = VerticalPassFactory(
        body=[],
        temp_defs=[],
        upper=common.AxisBound.from_start(offset=5),
        direction=common.LoopOrder.FORWARD,
    )
    code = generator.visit(forward_pass)
    print(code)
    expected_prefix = r"(#.*?\n)?k, K = _dk_, _dk_ \+ 5\nfor k_ in range\(k, K\):\n"
    assert re.match(expected_prefix, code, flags=re.MULTILINE)
def test_verticall_pass_end_end_backward() -> None:
    """A backward pass bounded from the end iterates ``k_`` from ``K-1`` down to ``k``.

    (The "verticall" spelling is kept so the pytest test id stays stable.)
    """
    generator = npir_gen.NpirGen()
    backward_pass = VerticalPassFactory(
        body=[],
        temp_defs=[],
        lower=common.AxisBound.from_end(offset=-4),
        upper=common.AxisBound.from_end(offset=-1),
        direction=common.LoopOrder.BACKWARD,
    )
    code = generator.visit(backward_pass)
    print(code)
    expected_prefix = (
        r"(#.*?\n)?"
        r"k, K = _dK_ \- 4, _dK_ \- 1\n"
        r"for k_ in range\(K-1, k-1, -1\):\n"
    )
    assert re.match(expected_prefix, code, flags=re.MULTILINE)
def test_vertical_pass_temp_def() -> None:
    """Temporaries of a vertical pass are allocated before its k-bounds and loop.

    The dots in ``np.zeros`` and ``np.int64`` are escaped so they match a
    literal ``.`` instead of any character. The pattern is split per emitted
    line, matching the style of the other vertical-pass tests.
    """
    result = npir_gen.NpirGen().visit(
        VerticalPassFactory(
            temp_defs=[
                VectorAssignFactory(temp_init=True, temp_name="a"),
            ],
            body=[],
            lower=common.AxisBound.from_end(offset=-4),
            upper=common.AxisBound.from_end(offset=-1),
            direction=common.LoopOrder.BACKWARD,
        ))
    # Echoed for debugging; pytest reports captured stdout on failure.
    print(result)
    match = re.match(
        (r"(#.*?\n)?"
         r"a_ = ShimmedView\(np\.zeros\(_domain_, dtype=np\.int64\), \[0, 0, 0\]\)\n"
         r"k, K = _dK_ \- 4, _dK_ \- 1\n"
         r"for k_ in range\(K-1, k-1, -1\):\n"),
        result,
        re.MULTILINE,
    )
    assert match
def test_vertical_pass_seq() -> None:
    """A forward pass from start+1 to end-2 binds shifted k-bounds and loops upward."""
    generator = npir_gen.NpirGen()
    sequential_pass = VerticalPassFactory(
        temp_defs=[],
        body=[],
        lower=common.AxisBound.from_start(offset=1),
        upper=common.AxisBound.from_end(offset=-2),
        direction=common.LoopOrder.FORWARD,
    )
    code = generator.visit(sequential_pass)
    print(code)
    expected_prefix = r"(#.*?\n)?k, K = _dk_ \+ 1, _dK_ - 2\nfor k_ in range\(k, K\):\n"
    assert re.match(expected_prefix, code, flags=re.MULTILINE)
def test_computation() -> None:
    """An empty computation still emits the numpy import and a keyword-only ``run``.

    After the signature, only an optional blank line and indented lines are expected.
    """
    generator = npir_gen.NpirGen()
    empty_computation = npir.Computation(
        params=[],
        field_params=[],
        field_decls=[],
        vertical_passes=[],
    )
    code = generator.visit(empty_computation, field_extents={})
    print(code)
    module_layout = (
        r"import numpy as np\n\n\n"
        r"def run\(\*, _domain_, _origin_\):\n"
        r"\n?"
        r"(    .*?\n)*"
    )
    assert re.match(module_layout, code, flags=re.MULTILINE)
def test_mask_block_slice_mask() -> None:
    """A mask that is a plain field slice, with an empty body, emits no code."""
    generator = npir_gen.NpirGen()
    block = npir.MaskBlock(
        body=[],
        mask=FieldSliceFactory(name="mask1"),
        mask_name="mask1",
    )
    assert generator.visit(block) == ""
def test_datatype() -> None:
    """``DataType.FLOAT64`` renders as the numpy type ``np.float64``.

    The dot is escaped so the pattern no longer accepts e.g. ``npXfloat64``.
    """
    result = npir_gen.NpirGen().visit(common.DataType.FLOAT64)
    # Echoed for debugging; pytest reports captured stdout on failure.
    print(result)
    match = re.match(r"np\.float64", result)
    assert match
def test_temp_definition() -> None:
    """A temporary without extents is zero-allocated over ``_domain_`` at origin 0."""
    generator = npir_gen.NpirGen()
    temp_assign = VectorAssignFactory(temp_init=True, temp_name="a")
    expected = "a_ = ShimmedView(np.zeros(_domain_, dtype=np.int64), [0, 0, 0])"
    assert generator.visit(temp_assign) == expected
def test_field_slice_sequential_k() -> None:
    """With ``parallel_k=False`` the k index is a one-level slice around ``k_``."""
    generator = npir_gen.NpirGen()
    field_slice = FieldSliceFactory(
        name="a_field",
        parallel_k=False,
        offsets=(-1, 0, 4),
    )
    expected = "a_field_[(i - 1):(I - 1), j:J, (k_ + 4):(k_ + 4 + 1)]"
    assert generator.visit(field_slice) == expected
def test_field_slice_parallel_k() -> None:
    """With ``parallel_k=True`` the k index spans the shifted ``k:K`` range."""
    generator = npir_gen.NpirGen()
    field_slice = FieldSliceFactory(
        name="another_field",
        parallel_k=True,
        offsets=(0, 0, -3),
    )
    expected = "another_field_[i:I, j:J, (k - 3):(K - 3)]"
    assert generator.visit(field_slice) == expected
def test_sequential_offset() -> None:
    """A k offset of 5 renders as a single-level slice shifted from ``k_``."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.k(5))
    print(rendered)
    assert rendered == "(k_ + 5):(k_ + 5 + 1)"
def test_sequential_offset_zero() -> None:
    """A zero k offset omits the arithmetic and slices ``k_`` directly."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.k(0))
    print(rendered)
    assert rendered == "k_:(k_ + 1)"
def test_numerical_offset_zero() -> None:
    """A zero numerical offset renders as the empty string."""
    generator = npir_gen.NpirGen()
    zero_offset = npir.NumericalOffset(value=0)
    assert generator.visit(zero_offset) == ""
def test_parallel_offset() -> None:
    """A negative i offset shifts both ends of the ``i:I`` slice."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.i(-3))
    print(rendered)
    assert rendered == "(i - 3):(I - 3)"
def test_parallel_offset_zero() -> None:
    """A zero j offset renders as the plain ``j:J`` slice."""
    generator = npir_gen.NpirGen()
    rendered = generator.visit(npir.AxisOffset.j(0))
    print(rendered)
    assert rendered == "j:J"