Exemplo n.º 1
0
def get_selector_from_legacy_operator_selection_list(
    selected_op_list_path: str,
) -> Any:
    """Build a selector from a legacy YAML operator-selection list.

    The file at ``selected_op_list_path`` is parsed with
    ``yaml.load(..., Loader=YamlLoader)``. Each entry is reduced to the text
    before its first ``"."``, dropping the overload suffix. For example,
    ``"aten::add.int"`` becomes ``"aten::add"``. The de-duplicated set of
    base names is handed to
    ``SelectiveBuilder.from_legacy_op_registration_allow_list`` with both the
    root-operator flag and the used-for-training flag set to ``True``.

    Args:
        selected_op_list_path: Path to the legacy YAML list of operator names.

    Returns:
        Whatever ``SelectiveBuilder.from_legacy_op_registration_allow_list``
        returns for that operator set.
    """
    # Overload stripping exists only for the legacy config format -- do NOT
    # copy this code elsewhere.
    with open(selected_op_list_path, "r") as op_list_file:
        base_op_names = set()
        for full_op_name in yaml.load(op_list_file, Loader=YamlLoader):
            base_op_names.add(full_op_name.partition(".")[0])

    # Internal builds no longer go through this path; only the OSS build does.
    # Every operator is therefore treated as a root operator (so unboxing code
    # is generated for it, matching the existing behaviour) and as used for
    # training, because OSS does not support training on mobile for now.
    treat_as_root = True
    treat_as_training = True

    # Deferred import, performed only after the file has been read.
    from torchgen.selective_build.selector import SelectiveBuilder

    return SelectiveBuilder.from_legacy_op_registration_allow_list(
        base_op_names,
        treat_as_root,
        treat_as_training,
    )
Exemplo n.º 2
0
    def test_selector_factory(self):
        yaml_config_v1 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::add:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: Yes
  aten::add.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
  aten::mul.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_v2 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::sub:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
    debug_info:
      - model1@v100
  aten::sub.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_all = "include_all_operators: Yes"

        yaml_config_invalid = "invalid:"

        selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)

        self.assertTrue(selector1.is_operator_selected("aten::add"))
        self.assertTrue(selector1.is_operator_selected("aten::add.int"))
        # Overload name is not used for checking in v1.
        self.assertTrue(selector1.is_operator_selected("aten::add.float"))

        def gen():
            return SelectiveBuilder.from_yaml_str(yaml_config_invalid)

        self.assertRaises(Exception, gen)

        selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)

        self.assertTrue(selector_all.is_operator_selected("aten::add"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
        self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))

        selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)

        self.assertFalse(selector2.is_operator_selected("aten::add"))
        self.assertTrue(selector2.is_operator_selected("aten::sub"))
        self.assertTrue(selector2.is_operator_selected("aten::sub.int"))

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            False,
            False,
        )
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
        self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))

        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            True,
            False,
        )

        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )
        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
        self.assertFalse(
            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
        )

        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            False,
            True,
        )

        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
        self.assertTrue(
            selector_legacy_v1.is_operator_selected_for_training("aten::add")
        )
        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
        self.assertTrue(
            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
        )