Ejemplo n.º 1
0
    def is_operator_selected_for_training(self, name: str) -> bool:
        """Return whether operator ``name`` is selected and used for training.

        The answer is False when ``self.is_operator_selected(name)`` is False.
        Otherwise it is True when ``self.include_all_operators`` is set, or when:
          * the ``self.operators`` entry for exactly ``name`` has
            ``is_used_for_training`` set, or
          * the entry for the overload-stripped form of ``name`` (via
            ``strip_operator_overload_name``) has both
            ``include_all_overloads`` and ``is_used_for_training`` set.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        # Stand-in for names with no entry in self.operators: every flag is
        # False, so a missing entry can never make the result truthy.
        placeholder = SelectiveBuildOperator(
            name='',
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        exact_entry = self.operators.get(name, placeholder)
        base_entry = self.operators.get(
            strip_operator_overload_name(name), placeholder
        )

        if exact_entry.is_used_for_training:
            return exact_entry.is_used_for_training
        # Only a base entry that opts in to all overloads can vouch for
        # this particular overload.
        return base_entry.include_all_overloads and base_entry.is_used_for_training
Ejemplo n.º 2
0
    def from_yaml_dict(data: Dict[str, object]) -> 'SelectiveBuilder':
        """Build a ``SelectiveBuilder`` from a dict loaded from a selection YAML file.

        Every top-level key is optional. A key that is present with a null
        value (``key:`` with nothing after it in YAML) is treated exactly like
        a missing key and falls back to the same default.

        Args:
            data: Top-level mapping. Allowed keys:
                ``include_all_operators`` / ``include_all_non_op_selectives``
                (bools, default False); ``debug_info`` (list, items converted
                to ``str``); ``operators`` (dict of name -> value, each passed
                to ``SelectiveBuildOperator.from_yaml_dict``);
                ``kernel_metadata`` (dict whose values are lists; keys and
                items converted to ``str``); ``custom_classes`` and
                ``build_features`` (lists, collected into sets).

        Returns:
            A new ``SelectiveBuilder`` built from those values.

        Raises:
            ValueError: ``data`` contains a key outside the allowed set.
            AssertionError: a key holds a value of the wrong type, e.g. a bare
                string where a list is expected. These checks are ``assert``
                statements and are therefore skipped under ``python -O``.
        """
        valid_top_level_keys = {
            'include_all_non_op_selectives',
            'include_all_operators',
            'debug_info',
            'operators',
            'kernel_metadata',
            'custom_classes',
            'build_features',
        }
        unexpected_keys = set(data.keys()) - valid_top_level_keys
        if unexpected_keys:
            # ValueError subclasses Exception, so existing `except Exception`
            # handlers still catch it. Sorting keeps the message stable across
            # runs (set iteration order is not).
            raise ValueError("Got unexpected top level keys: {}".format(
                ",".join(sorted(unexpected_keys)),
            ))

        def _get(key: str, default: object) -> object:
            # An empty YAML value (`key:`) loads as None; treat it like an
            # absent key instead of crashing in set(None) / the asserts below.
            value = data.get(key)
            return default if value is None else value

        def _as_list(key: str, value: object) -> list:
            # Iterating a bare string would silently split it into single
            # characters, so require an actual collection.
            assert isinstance(value, (list, tuple, set, frozenset)), (
                "Expected a list for '{}', got {}".format(
                    key, type(value).__name__)
            )
            return list(value)

        include_all_operators = _get('include_all_operators', False)
        assert isinstance(include_all_operators, bool)

        debug_info = None
        di_list = _get('debug_info', None)
        if di_list is not None:
            debug_info = tuple(str(x) for x in _as_list('debug_info', di_list))

        operators_dict = _get('operators', {})
        assert isinstance(operators_dict, dict)
        operators = {
            k: SelectiveBuildOperator.from_yaml_dict(k, v)
            for k, v in operators_dict.items()
        }

        kernel_metadata_dict = _get('kernel_metadata', {})
        assert isinstance(kernel_metadata_dict, dict)
        kernel_metadata = {}
        for k, v in kernel_metadata_dict.items():
            dtypes = _as_list('kernel_metadata[{}]'.format(k), v)
            kernel_metadata[str(k)] = [str(dtype) for dtype in dtypes]

        custom_classes = set(
            _as_list('custom_classes', _get('custom_classes', [])))
        build_features = set(
            _as_list('build_features', _get('build_features', [])))

        include_all_non_op_selectives = _get('include_all_non_op_selectives', False)
        assert isinstance(include_all_non_op_selectives, bool)

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            custom_classes,
            build_features,
            include_all_non_op_selectives,
        )