示例#1
0
def _generate_provider_registrations(
        ort_root: Path, build_dir: Path, use_cuda: bool,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
    '''Generate provider registration files.

    For every kernel registration file reported by
    op_registration_utils.get_kernel_registration_files(), write a reduced copy to
    the path chosen by _get_op_reduction_file_path(). In the copy, kernels that are
    not required are written out commented out by _ExcludingRegistrationProcessor.

    :param ort_root: ONNX Runtime repository root (converted to str for the file lookup).
    :param build_dir: Build directory, forwarded to _get_op_reduction_file_path().
    :param use_cuda: Forwarded to get_kernel_registration_files() to select the CUDA files as well.
    :param required_ops: Forwarded to the processor; may be None.
    :param op_type_impl_filter: Optional filter forwarded to the processor.
    :raises ValueError: If a reported registration file does not exist.
    Calls sys.exit(-1) if the processor reports a failure.
    '''
    source_files = list(map(
        Path, op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)))

    for source_file in source_files:
        if not source_file.is_file():
            raise ValueError(f'Kernel registration file does not exist: {source_file}')

        log.info(f"Processing {source_file}")

        target_file = _get_op_reduction_file_path(ort_root, build_dir, source_file)
        # The reduced file may live in a directory that doesn't exist yet.
        target_file.parent.mkdir(parents=True, exist_ok=True)

        with open(target_file, 'w') as out_stream:
            excluder = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, out_stream)
            op_registration_utils.process_kernel_registration_file(source_file, excluder)

            if not excluder.ok():
                # the processor is expected to have logged the problem already
                sys.exit(-1)
示例#2
0
def _process_provider_registrations(
        ort_root: str, use_cuda: bool,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
    '''Rewrite provider registration files.

    For each kernel registration file, a sibling file named
    ``<stem><REDUCED_KERNEL_DEF_SUFFIX><suffix>`` is created. Its content is the original
    file as emitted by _ExcludingRegistrationProcessor, with kernels that are not
    required commented out. Afterwards, every ``#ifndef REDUCED_OPS_BUILD`` in the
    generated file is changed to ``#ifdef REDUCED_OPS_BUILD`` to enable its contents.

    :param ort_root: ONNX Runtime repository root, forwarded to get_kernel_registration_files().
    :param use_cuda: Whether the CUDA registration files are included.
    :param required_ops: Forwarded to the processor; may be None.
    :param op_type_impl_filter: Optional filter forwarded to the processor.
    :raises ValueError: If a registration file does not exist.
    Calls sys.exit(-1) if the processor reports a failure.
    '''
    kernel_registration_files = op_registration_utils.get_kernel_registration_files(ort_root, use_cuda)

    for kernel_registration_file in kernel_registration_files:
        if not os.path.isfile(kernel_registration_file):
            raise ValueError('Kernel registration file {} does not exist'.format(kernel_registration_file))

        log.info("Processing {}".format(kernel_registration_file))

        old_path = Path(kernel_registration_file)
        reduced_path = Path(old_path.parent, f'{old_path.stem}{REDUCED_KERNEL_DEF_SUFFIX}{old_path.suffix}')

        # read from original and create the reduced kernel def file (*_reduced_ops.cc),
        # with commented out lines for any kernels that are not required
        with open(reduced_path, 'w') as file_to_write:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, file_to_write)

            op_registration_utils.process_kernel_registration_file(kernel_registration_file, processor)

            if not processor.ok():
                # error should have already been logged so just exit
                sys.exit(-1)

        # Enable the contents of the *_reduced_ops.cc file by flipping its REDUCED_OPS_BUILD guard.
        # The file is only read here, so there is no need for the read/write ('r+') mode; read_text/write_text
        # use the same default encoding and newline handling as open().
        file_content = reduced_path.read_text()
        reduced_path.write_text(file_content.replace('#ifndef REDUCED_OPS_BUILD', '#ifdef REDUCED_OPS_BUILD'))
示例#3
0
def _process_provider_registrations(
        ort_root: str, use_cuda: bool, required_ops: dict,
        op_type_usage_manager: typing.Optional[OperatorTypeUsageManager],
        globally_allowed_types: typing.Optional[typing.Set[str]]):
    '''Rewrite provider registration files in place.

    Each kernel registration file is first moved aside to ``<file>~``. The backup
    is then fed through _ExcludingRegistrationProcessor, whose output recreates the
    file at its original path with the kernels that are not required commented out.

    :param ort_root: ONNX Runtime repository root, forwarded to get_kernel_registration_files().
    :param use_cuda: Whether the CUDA registration files are included.
    :param required_ops: Forwarded to the processor.
    :param op_type_usage_manager: Optional, forwarded to the processor.
    :param globally_allowed_types: Optional set of type names, forwarded to the processor.
    :raises ValueError: If a registration file does not exist.
    Calls sys.exit(-1) if the processor reports a failure.
    '''
    for registration_file in op_registration_utils.get_kernel_registration_files(ort_root, use_cuda):
        if not os.path.isfile(registration_file):
            raise ValueError('Kernel registration file {} does not exist'.format(registration_file))

        log.info(f"Processing {registration_file}")

        # String concatenation (not Path arithmetic): the backup sits next to the original with a '~' suffix.
        saved_original = registration_file + '~'
        shutil.move(registration_file, saved_original)

        # The backup is the input; the original path is recreated with the filtered output.
        with open(registration_file, 'w') as rewritten:
            excluder = _ExcludingRegistrationProcessor(
                required_ops, op_type_usage_manager, globally_allowed_types, rewritten)
            op_registration_utils.process_kernel_registration_file(saved_original, excluder)

            if not excluder.ok():
                # the processor is expected to have logged the problem already
                sys.exit(-1)
示例#4
0
def exclude_unused_ops_and_types(config_path, enable_type_reduction=False, use_cuda=True):
    '''Exclude unused operators (and optionally types) from the kernel registrations based on a config file.

    The config is parsed by parse_config(). The resulting required operators and
    type usage manager are applied to the kernel registration files. C++ code for
    the required types is then generated via _generate_required_types_cpp_code().

    NOTE(review): ``ort_root`` is not a parameter here; it is read from module scope.
    Confirm that it is defined at module level before this is called.

    :param config_path: Path to the configuration file.
    :param enable_type_reduction: Forwarded to parse_config().
    :param use_cuda: Whether the CUDA registration files are included.
    '''
    required_ops, op_type_usage_manager = parse_config(config_path, enable_type_reduction)

    _exclude_unused_ops_and_types_in_registrations(
        required_ops,
        op_type_usage_manager,
        op_registration_utils.get_kernel_registration_files(ort_root, use_cuda),
    )

    _generate_required_types_cpp_code(ort_root, op_type_usage_manager)


def exclude_unused_ops(models_path, config_path, ort_root=None, use_cuda=True, output_config_path=None):
    '''Determine the operators that are used, and either exclude the unused ones or create a config file.

    If ``output_config_path`` is given, a configuration file listing the required
    operators is written and no registration files are touched. Otherwise the kernel
    registration files are updated to exclude operators that are not required.
    Note that this is called directly from build.py.

    :param models_path: Path to the model(s) to extract required operators from; may be empty.
    :param config_path: Path to a configuration file listing required operators; may be empty.
    :param ort_root: ONNX Runtime repository root. A falsy value is passed through unchanged to
                     op_registration_utils.get_kernel_registration_files(). The log message below says
                     the root is then inferred from the script location.
    :param use_cuda: Whether the CUDA registration files are included.
    :param output_config_path: If set, write the required operators to this config file instead.
    Calls sys.exit(-1) if neither models_path nor config_path is provided.
    '''
    if not models_path and not config_path:
        # Message now names the actual parameter ('models_path', not 'model_path').
        log.error('Please specify models_path and/or config_path.')
        sys.exit(-1)

    if not ort_root and not output_config_path:
        log.info('ort_root was not specified. Inferring ONNX Runtime repository root from location of this script.')

    # Operators extracted from the models are passed into the config extraction.
    # Presumably the two sets are merged there — verify in _extract_ops_from_config.
    required_ops = _extract_ops_from_config(config_path, _extract_ops_from_model(models_path, {}))

    if output_config_path:
        _create_config_file_with_required_ops(required_ops, models_path, config_path, output_config_path)
    else:
        registration_files = op_registration_utils.get_kernel_registration_files(ort_root, use_cuda)
        _exclude_unused_ops_in_registrations(required_ops, registration_files)


if __name__ == "__main__":
    # Command-line entry point. It validates the kernel registration files of the CPU
    # and CUDA execution providers and calls sys.exit(-1) on the first file whose
    # validator reports a failure.
    # Module-level names (parser, args, ort_root, include_cuda, registration_files, file,
    # processor) are kept as-is because other top-level code may read them.
    parser = argparse.ArgumentParser(description="Script to validate operator kernel registrations.")
    parser.add_argument(
        "--ort_root",
        type=str,
        help=("Path to ONNXRuntime repository root. "
              "Inferred from the location of this script if not provided."),
    )
    args = parser.parse_args()

    # When --ort_root is omitted, an empty string is passed to the lookup helper.
    # Per the help text, the root is then inferred from the script location.
    if args.ort_root:
        ort_root = os.path.abspath(args.ort_root)
    else:
        ort_root = ""
    include_cuda = True  # validate CPU and CUDA EP registrations

    registration_files = op_registration_utils.get_kernel_registration_files(ort_root, include_cuda)

    for file in registration_files:
        log.info(f"Processing {file}")

        processor = RegistrationValidator()
        op_registration_utils.process_kernel_registration_file(file, processor)
        # Final check after the whole file was processed. Based on the name, this validates
        # the last registrations seen — confirm in RegistrationValidator.
        processor.validate_last_registrations()

        if not processor.ok():
            sys.exit(-1)