Beispiel #1
0
def test_add_lower_pass_invalid_format():
    """Malformed ``tir.add_lower_pass`` values must be rejected.

    Each case is parsed on its own and must raise the exception paired
    with it.
    """
    invalid_cases = (
        # opt level and pass name are swapped
        ("tir.add_lower_pass=tir.transform.UnrollLoop,1", TVMCException),
        # trailing opt level "3" has no pass name after it
        ("tir.add_lower_pass=1,tir.transform.UnrollLoop,3", TVMCException),
        # opt level "a" is not an integer
        ("tir.add_lower_pass=a,tir.transform.UnrollLoop", TVMCException),
        # the module path of the second pass does not exist
        (
            "tir.add_lower_pass=1,tir.transform.UnrollLoop,2,path.to.module.fake_func",
            ModuleNotFoundError,
        ),
        # the module is real, but the named function is not
        (
            "tir.add_lower_pass=1,tir.transform.UnrollLoop,2,tvm.tir.fake_func",
            TVMCException,
        ),
    )
    for config, expected_error in invalid_cases:
        with pytest.raises(expected_error):
            parse_configs([config])
Beispiel #2
0
def test_config_valid_multiple_configs():
    """Parse a bool, an int and a string config in a single call.

    ``"false"`` must parse to the boolean ``False``, ``"10"`` to ``10``, and
    the string config must be kept verbatim.
    """
    configs = parse_configs([
        "relay.backend.use_auto_scheduler=false",
        "tir.detect_global_barrier=10",
        "relay.ext.vitis_ai.options.build_dir=mystring",
    ])

    assert len(configs) == 3
    assert "relay.backend.use_auto_scheduler" in configs
    # `is False` rather than `== False`: also rejects falsy non-bools like 0.
    assert configs["relay.backend.use_auto_scheduler"] is False
    assert "tir.detect_global_barrier" in configs
    assert configs["tir.detect_global_barrier"] == 10
    assert "relay.ext.vitis_ai.options.build_dir" in configs
    assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"
Beispiel #3
0
def test_add_lower_pass_multi_built_in_pass():
    """Repeated ``tir.add_lower_pass`` entries accumulate in order.

    Built-in TIR passes are referenced by dotted name. Each parsed entry is
    an (opt_level, pass) pair whose pass must be a ``PrimFuncPass``.
    """
    configs = parse_configs(
        [
            "tir.add_lower_pass=1,tir.transform.UnrollLoop",
            "tir.add_lower_pass=1,tir.transform.HoistIfThenElse,2,tir.transform.LoopPartition",
        ]
    )
    lower_passes = configs["tir.add_lower_pass"]

    assert len(lower_passes) == 3
    # Expected opt levels, in order:
    # UnrollLoop (1), HoistIfThenElse (1), LoopPartition (2).
    for entry, expected_level in zip(lower_passes, (1, 1, 2)):
        assert entry[0] == expected_level
        assert isinstance(entry[1], PrimFuncPass)
Beispiel #4
0
def test_add_lower_pass_multi_mix_pass():
    """Mix external (mocked) and built-in lower passes across two entries.

    A mock ``fake_module`` is placed in ``sys.modules`` so that importing
    ``fake_module`` resolves to it instead of a real module on disk.
    """
    # Both passes live on ONE module mock. A dict literal can hold the key
    # "fake_module" only once, so separate per-pass entries under that key
    # would silently overwrite each other.
    fake_module = mock.MagicMock()
    fake_module.fake_pass_1 = mock.MagicMock()
    fake_module.fake_pass_2 = mock.MagicMock()
    with mock.patch.dict("sys.modules", {"fake_module": fake_module}):
        configs = parse_configs([
            "tir.add_lower_pass=1,fake_module.fake_pass_1,1,tir.transform.UnrollLoop",
            "tir.add_lower_pass=2,fake_module.fake_pass_2,2,tir.transform.LoopPartition",
        ])
        lower_passes = configs["tir.add_lower_pass"]
        assert len(lower_passes) == 4
        # opt_level: 1, pass: fake_module.fake_pass_1
        assert lower_passes[0][0] == 1
        # opt_level: 1, pass: tir.transform.UnrollLoop
        assert lower_passes[1][0] == 1
        assert isinstance(lower_passes[1][1], PrimFuncPass)
        # opt_level: 2, pass: fake_module.fake_pass_2
        assert lower_passes[2][0] == 2
        # opt_level: 2, pass: tir.transform.LoopPartition
        assert lower_passes[3][0] == 2
        assert isinstance(lower_passes[3][1], PrimFuncPass)
Beispiel #5
0
def test_add_lower_pass_multi_external_pass():
    """Parse several external (mocked) lower passes across two entries.

    A mock ``fake_module`` is placed in ``sys.modules`` so that importing
    ``fake_module`` resolves to it instead of a real module on disk.
    """
    # All three passes live on ONE module mock. A dict literal can hold the
    # key "fake_module" only once, so per-pass entries under that key would
    # silently overwrite each other.
    fake_module = mock.MagicMock()
    fake_module.fake_pass_1 = mock.MagicMock()
    fake_module.fake_pass_2 = mock.MagicMock()
    fake_module.fake_pass_3 = mock.MagicMock()
    with mock.patch.dict("sys.modules", {"fake_module": fake_module}):
        configs = parse_configs([
            # Name the pass `fake_pass_2` (previously mistyped `fake_pass2`),
            # matching the attribute defined above.
            "tir.add_lower_pass=1,fake_module.fake_pass_1,2,fake_module.fake_pass_2",
            "tir.add_lower_pass=3,fake_module.fake_pass_3",
        ])
        lower_passes = configs["tir.add_lower_pass"]
        assert len(lower_passes) == 3
        # opt_level: 1, pass: fake_module.fake_pass_1
        assert lower_passes[0][0] == 1
        # opt_level: 2, pass: fake_module.fake_pass_2
        assert lower_passes[1][0] == 2
        # opt_level: 3, pass: fake_module.fake_pass_3
        assert lower_passes[2][0] == 3
Beispiel #6
0
def test_config_valid_config_bool():
    """A ``"true"`` value for a bool config must parse to the boolean ``True``."""
    configs = parse_configs(["relay.backend.use_auto_scheduler=true"])

    assert len(configs) == 1
    assert "relay.backend.use_auto_scheduler" in configs
    # `is True` rather than `== True`: also rejects truthy non-bools like 1.
    assert configs["relay.backend.use_auto_scheduler"] is True
Beispiel #7
0
def test_config_empty():
    """An empty config string must be rejected with ``TVMCException``."""
    empty_config = ""
    with pytest.raises(TVMCException):
        parse_configs([empty_config])
Beispiel #8
0
def test_config_unsupported_tvmc_config():
    """A config name TVMC does not support must raise ``TVMCException``."""
    unsupported = "tir.LoopPartition=value"
    with pytest.raises(TVMCException):
        parse_configs([unsupported])
Beispiel #9
0
def test_config_missing_from_tvm():
    """A config name that TVM does not define must raise ``TVMCException``."""
    undefined = "relay.backend.use_auto_scheduler.missing.value=1234"
    with pytest.raises(TVMCException):
        parse_configs([undefined])
Beispiel #10
0
def test_config_invalid_format():
    """A config with no ``=value`` part must raise ``TVMCException``."""
    no_value = "relay.backend.use_auto_scheduler.missing.value"
    with pytest.raises(TVMCException):
        parse_configs([no_value])