예제 #1
0
    def test_kernel_dtypes(self):
        yaml_config = """
kernel_metadata:
  add_kernel:
    - int8
    - int32
  sub_kernel:
    - int16
    - int32
  add/sub_kernel:
    - float
    - complex
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))

        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
예제 #2
0
    def test_merge_kernel_dtypes(self):
        yaml_config1 = """
kernel_metadata:
  add_kernel:
    - int8
  add/sub_kernel:
    - float
    - complex
    - none
  mul_kernel:
    - int8
"""

        yaml_config2 = """
kernel_metadata:
  add_kernel:
    - int32
  sub_kernel:
    - int16
    - int32
  add/sub_kernel:
    - float
    - complex
"""

        selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
        selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)

        selector = combine_selective_builders(selector1, selector2)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))

        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))

        self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
예제 #3
0
    def test_all_kernel_dtypes_selected(self):
        yaml_config = """
include_all_non_op_selectives: True
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
예제 #4
0
    def test_training_op_fetch(self):
        yaml_config = """
operators:
  aten::add.int:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
  aten::add:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: Yes
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)
        self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
        self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
예제 #5
0
    def test_custom_namespace_selected_correctly(self):
        yaml_config = """
operators:
  aten::add.int:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
  custom::add:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: Yes
"""
        selector = SelectiveBuilder.from_yaml_str(yaml_config)
        native_function, _ = NativeFunction.from_yaml(
            {"func": "custom::add() -> Tensor"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        self.assertTrue(selector.is_native_function_selected(native_function))
예제 #6
0
 def gen():
     """Build a SelectiveBuilder from ``yaml_config_invalid``.

     NOTE(review): ``yaml_config_invalid`` is a free variable that is not
     defined in this snippet; it must be supplied by an enclosing or module
     scope. In ``test_selector_factory`` the same nested helper wraps the
     local string ``"invalid:"`` and is passed to ``assertRaises``, i.e. the
     call is expected to raise.
     """
     return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
예제 #7
0
    def test_selector_factory(self):
        yaml_config_v1 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::add:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: Yes
  aten::add.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
  aten::mul.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_v2 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::sub:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
    debug_info:
      - model1@v100
  aten::sub.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_all = "include_all_operators: Yes"

        yaml_config_invalid = "invalid:"

        selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)

        self.assertTrue(selector1.is_operator_selected("aten::add"))
        self.assertTrue(selector1.is_operator_selected("aten::add.int"))
        # Overload name is not used for checking in v1.
        self.assertTrue(selector1.is_operator_selected("aten::add.float"))

        def gen():
            return SelectiveBuilder.from_yaml_str(yaml_config_invalid)

        self.assertRaises(Exception, gen)

        selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)

        self.assertTrue(selector_all.is_operator_selected("aten::add"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
        self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))

        selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)

        self.assertFalse(selector2.is_operator_selected("aten::add"))
        self.assertTrue(selector2.is_operator_selected("aten::sub"))
        self.assertTrue(selector2.is_operator_selected("aten::sub.int"))

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            False,
            False,
        )
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
        self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))

        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            True,
            False,
        )

        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )
        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
        )

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            False,
            True,
        )

        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertTrue(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )
        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
        self.assertTrue(
            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
        )