def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """Check that a scalar literal is rendered as ``<dtype-prefix>(42)``.

    The text in front of ``(42)`` is captured lazily and handed to
    ``match_dtype`` together with the expected ``defined_dtype``
    (presumably ``match_dtype`` asserts the two agree — confirm).
    """
    codegen = NpirCodegen()
    generated = codegen.visit(npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    print(generated)

    found = re.match(r"(.+?)\(42\)", generated)
    assert found

    dtype_text = found.group(1)
    match_dtype(dtype_text, defined_dtype)
def test_broadcast_literal(defined_dtype: common.DataType, is_serial: bool) -> None:
    """Check that a broadcast scalar literal is rendered as ``<dtype-prefix>(42)``.

    The dtype prefix in front of ``(42)`` must be non-empty and is passed to
    ``match_dtype`` with the expected ``defined_dtype``.

    NOTE(review): ``is_serial`` is accepted (presumably supplied by a pytest
    fixture/parametrization) but is not used here; the Broadcast node is
    visited without it. Confirm whether the codegen should receive it.
    """
    result = NpirCodegen().visit(
        npir.Broadcast(expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    )
    print(result)
    # Require a non-empty prefix: an empty capture cannot name a dtype.
    match = re.match(r"(.+?)\(42\)", result)
    assert match, f"unexpected broadcast codegen: {result!r}"
    match_dtype(match.group(1), defined_dtype)
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Check that a scalar cast is rendered as ``np.<other>(np.<defined>(42))``.

    Both captured dtype names must equal the lower-cased ``DataType`` member
    names of ``other_dtype`` and ``defined_dtype`` respectively.

    NOTE(review): this file appears to define ``test_scalar_cast`` more than
    once; if so, only the last definition in the module is collected.
    """
    result = NpirCodegen().visit(
        npir.ScalarCast(
            dtype=other_dtype,
            expr=npir.ScalarLiteral(dtype=defined_dtype, value="42"),
        )
    )
    print(result)
    # Both ``np.`` prefixes are escaped so the dot matches only a literal '.',
    # and dtype names must be non-empty.
    match = re.match(
        r"np\.(?P<other_dtype>\w+)\(np\.(?P<defined_dtype>\w+)\(42\)\)", result
    )
    assert match, f"unexpected scalar-cast codegen: {result!r}"
    assert match.group("defined_dtype") == defined_dtype.name.lower()
    assert match.group("other_dtype") == other_dtype.name.lower()
def test_scalar_cast(defined_dtype: common.DataType, other_dtype: common.DataType) -> None:
    """Check that a scalar cast renders as ``<other-prefix>(<defined-prefix>(42))``.

    Each captured dtype prefix is handed to ``match_dtype`` alongside the
    corresponding expected ``DataType`` (defined dtype first, then the cast
    target).
    """
    codegen = NpirCodegen()
    inner = npir.ScalarLiteral(dtype=defined_dtype, value="42")
    node = npir.ScalarCast(dtype=other_dtype, expr=inner)
    rendered = codegen.visit(node)
    print(rendered)

    found = re.match(r"(?P<other_dtype>.+)\((?P<defined_dtype>.+)\(42\)\)", rendered)
    assert found

    expectations = (("defined_dtype", defined_dtype), ("other_dtype", other_dtype))
    for group_name, expected in expectations:
        match_dtype(found.group(group_name), expected)
def test_broadcast_literal(defined_dtype: common.DataType, is_serial: bool) -> None:
    """Check the ``np.full`` expression generated for a broadcast scalar literal.

    With ``lower=(0, 0)`` and ``upper=(0, 0)`` the I/J shape terms must be
    ``_dI_ + 0`` and ``_dJ_ + 0``. The third shape term must be ``1`` when
    ``is_serial`` is true and ``K - k`` otherwise. The fill value must be
    ``np.<dtype>(42)`` with the lower-cased ``DataType`` member name.
    """
    result = NpirCodegen().visit(
        npir.Broadcast(expr=npir.ScalarLiteral(dtype=defined_dtype, value="42")),
        is_serial=is_serial,
        lower=(0, 0),
        upper=(0, 0),
    )
    print(result)
    match = re.match(
        r"np\.full\(\(_dI_\s*\+\s*(?P<iext>\d+)\s*,\s*_dJ_\s*\+\s*(?P<jext>\d+)\s*,\s*(?P<kbounds>[^\)]+)\),\s*np\.(?P<dtype>\w+)\(42\)\)",
        result,
    )
    assert match, f"unexpected broadcast codegen: {result!r}"
    assert match.group("iext", "jext") == ("0", "0")
    # Compute the expectation separately: written inline without parentheses,
    # the assert parses as ``assert (kbounds == "1") if is_serial else "K - k"``,
    # which always passes in the non-serial case because "K - k" is truthy.
    expected_kbounds = "1" if is_serial else "K - k"
    assert match.group("kbounds") == expected_kbounds
    assert match.group("dtype") == defined_dtype.name.lower()
def test_scalarliteral(defined_dtype: common.DataType) -> None:
    """Check that a scalar literal is rendered as ``np.<dtype>(42)``.

    The captured dtype name must equal the lower-cased ``DataType`` member
    name of ``defined_dtype``.
    """
    result = NpirCodegen().visit(npir.ScalarLiteral(dtype=defined_dtype, value="42"))
    print(result)
    # Escape the dot so only a literal ``np.`` prefix is accepted, and require
    # a non-empty dtype name.
    match = re.match(r"np\.(\w+?)\(42\)", result)
    assert match, f"unexpected scalar-literal codegen: {result!r}"
    assert match.group(1) == defined_dtype.name.lower()