示例#1
0
def _generate_provider_registrations(
        ort_root: Path, build_dir: Path, use_cuda: bool,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
    '''
    Write a reduced copy of every kernel registration file into the build directory.

    The list of registration files comes from op_registration_utils.get_kernel_registration_files(). Each original
    file is only read; its reduced copy is written to the path returned by _get_op_reduction_file_path(), with
    _ExcludingRegistrationProcessor commenting out the kernels that are not required.

    :param ort_root: Root of the ONNX Runtime repository.
    :param build_dir: Build directory used to compute where the reduced files go.
    :param use_cuda: Forwarded to get_kernel_registration_files() to select the CUDA registration files as well.
    :param required_ops: Required operators (may be None); forwarded to the processor.
    :param op_type_impl_filter: Optional op type filter; forwarded to the processor.
    :raises ValueError: If a listed registration file does not exist.
    Exits the process with status -1 when the processor reports a problem.
    '''
    listed = op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)
    source_paths = [Path(entry) for entry in listed]

    for source_path in source_paths:
        if not source_path.is_file():
            raise ValueError(f'Kernel registration file does not exist: {source_path}')

        log.info(f"Processing {source_path}")

        target_path = _get_op_reduction_file_path(ort_root, build_dir, source_path)
        # the destination directory inside the build tree may not exist yet
        target_path.parent.mkdir(parents=True, exist_ok=True)

        # copy the original into the reduced file, commenting out registrations that are not needed
        with open(target_path, 'w') as reduced_file:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, reduced_file)
            op_registration_utils.process_kernel_registration_file(source_path, processor)
            succeeded = processor.ok()

        if not succeeded:
            # the processor is expected to have logged the problem already; just stop
            sys.exit(-1)
示例#2
0
def _exclude_unused_ops_and_types_in_registrations(
        required_operators, op_type_usage_manager,
        provider_registration_paths):
    '''
    Rewrite provider registration files in place to exclude unused ops and types.

    Each file is moved to a backup path (the original path with '~' appended). The backup is then read and the
    original path is rewritten with ExcludeOpsAndTypesRegistrationProcessor commenting out any kernels that are
    not required. On success the backup is left next to the rewritten file.

    If processing raises, or the processor reports an error, the backup is moved back over the partially
    written file. This keeps the original registrations intact. It also prevents a later re-run from moving
    the broken file over the only good copy in the backup.

    :param required_operators: Required operators; forwarded to the processor.
    :param op_type_usage_manager: Type usage information; forwarded to the processor.
    :param provider_registration_paths: Iterable of registration file paths (str or os.PathLike).
    :raises ValueError: If a registration file does not exist.
    Exits the process with status -1 when the processor reports an error.
    '''
    for kernel_registration_file in provider_registration_paths:
        if not os.path.isfile(kernel_registration_file):
            raise ValueError(
                'Kernel registration file {} does not exist'.format(
                    kernel_registration_file))

        log.info("Processing {}".format(kernel_registration_file))

        # os.fspath() accepts both str and Path; plain `path + '~'` raises TypeError for a Path
        backup_path = os.fspath(kernel_registration_file) + '~'
        shutil.move(kernel_registration_file, backup_path)

        succeeded = False
        try:
            # read from backup and overwrite original with commented out lines for any kernels that are not required
            with open(kernel_registration_file, 'w') as file_to_write:
                processor = ExcludeOpsAndTypesRegistrationProcessor(
                    required_operators, op_type_usage_manager, file_to_write)

                op_registration_utils.process_kernel_registration_file(
                    backup_path, processor)

                succeeded = processor.ok()
        finally:
            if not succeeded:
                # put the untouched original back rather than leaving a half-rewritten file behind
                os.replace(backup_path, kernel_registration_file)

        if not succeeded:
            # error should have already been logged so just exit
            sys.exit(-1)
示例#3
0
def _process_provider_registrations(
        ort_root: str, use_cuda: bool,
        required_ops: typing.Optional[dict],
        op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]):
    '''
    Create a reduced copy next to every kernel registration file.

    For each file returned by op_registration_utils.get_kernel_registration_files(), a sibling file named
    <stem><REDUCED_KERNEL_DEF_SUFFIX><suffix> is written. In it, _ExcludingRegistrationProcessor has commented
    out the kernels that are not required. The original file is only read.

    Afterwards every '#ifndef REDUCED_OPS_BUILD' in the reduced copy is changed to '#ifdef REDUCED_OPS_BUILD'
    to enable its contents.

    :param ort_root: Root of the ONNX Runtime repository.
    :param use_cuda: Forwarded to get_kernel_registration_files() to select the CUDA registration files as well.
    :param required_ops: Required operators (may be None); forwarded to the processor.
    :param op_type_impl_filter: Optional op type filter; forwarded to the processor.
    :raises ValueError: If a registration file does not exist.
    Exits the process with status -1 when the processor reports an error.
    '''
    kernel_registration_files = op_registration_utils.get_kernel_registration_files(ort_root, use_cuda)

    for kernel_registration_file in kernel_registration_files:
        if not os.path.isfile(kernel_registration_file):
            raise ValueError('Kernel registration file {} does not exist'.format(kernel_registration_file))

        log.info("Processing {}".format(kernel_registration_file))

        old_path = Path(kernel_registration_file)
        reduced_path = Path(old_path.parent, f'{old_path.stem}{REDUCED_KERNEL_DEF_SUFFIX}{old_path.suffix}')

        # read from original and create the reduced kernel def file (presumably *_reduced_ops.cc),
        # with commented out lines for any kernels that are not required
        with open(reduced_path, 'w') as file_to_write:
            processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, file_to_write)

            op_registration_utils.process_kernel_registration_file(kernel_registration_file, processor)

            if not processor.ok():
                # error should have already been logged so just exit
                sys.exit(-1)

        # Enable the contents of the reduced file by flipping its REDUCED_OPS_BUILD guard(s).
        # The file only needs to be read here, so no read/write ('r+') handle is opened. read_text/write_text
        # use the same default encoding and newline translation as open().
        reduced_content = reduced_path.read_text()
        reduced_path.write_text(reduced_content.replace(r'#ifndef REDUCED_OPS_BUILD', r'#ifdef REDUCED_OPS_BUILD'))
示例#4
0
def _process_provider_registrations(
        ort_root: str, use_cuda: bool, required_ops: dict,
        op_type_usage_manager: typing.Optional[OperatorTypeUsageManager],
        globally_allowed_types: typing.Optional[typing.Set[str]]):
    '''
    Rewrite provider registration files in place so kernels that are not required are commented out.

    Each file returned by op_registration_utils.get_kernel_registration_files() is moved to a backup path (the
    original path with '~' appended). The backup is then read and the original path is rewritten through
    _ExcludingRegistrationProcessor. On success the backup is left next to the rewritten file.

    If processing raises, or the processor reports an error, the backup is moved back over the partially
    written file. This keeps the original registrations intact. It also prevents a later re-run from moving
    the broken file over the only good copy in the backup.

    :param ort_root: Root of the ONNX Runtime repository.
    :param use_cuda: Forwarded to get_kernel_registration_files() to select the CUDA registration files as well.
    :param required_ops: Required operators; forwarded to the processor.
    :param op_type_usage_manager: Optional type usage information; forwarded to the processor.
    :param globally_allowed_types: Optional set of allowed type names; forwarded to the processor.
    :raises ValueError: If a registration file does not exist.
    Exits the process with status -1 when the processor reports an error.
    '''
    kernel_registration_files = op_registration_utils.get_kernel_registration_files(
        ort_root, use_cuda)

    for kernel_registration_file in kernel_registration_files:
        if not os.path.isfile(kernel_registration_file):
            raise ValueError(
                'Kernel registration file {} does not exist'.format(
                    kernel_registration_file))

        log.info("Processing {}".format(kernel_registration_file))

        # os.fspath() accepts both str and Path; plain `path + '~'` raises TypeError for a Path
        backup_path = os.fspath(kernel_registration_file) + '~'
        shutil.move(kernel_registration_file, backup_path)

        succeeded = False
        try:
            # read from backup and overwrite original with commented out lines for any kernels that are not required
            with open(kernel_registration_file, 'w') as file_to_write:
                processor = _ExcludingRegistrationProcessor(
                    required_ops, op_type_usage_manager, globally_allowed_types,
                    file_to_write)

                op_registration_utils.process_kernel_registration_file(
                    backup_path, processor)

                succeeded = processor.ok()
        finally:
            if not succeeded:
                # put the untouched original back rather than leaving a half-rewritten file behind
                os.replace(backup_path, kernel_registration_file)

        if not succeeded:
            # error should have already been logged so just exit
            sys.exit(-1)
if __name__ == "__main__":
    # Command line: only the optional repository root can be configured.
    arg_parser = argparse.ArgumentParser(
        description="Script to validate operator kernel registrations.")
    arg_parser.add_argument(
        "--ort_root",
        type=str,
        help="Path to ONNXRuntime repository root. "
        "Inferred from the location of this script if not provided.",
    )
    args = arg_parser.parse_args()

    # An empty string is forwarded when --ort_root is absent.
    # NOTE(review): the help text implies get_kernel_registration_files() infers the root from "" — confirm.
    ort_root = "" if not args.ort_root else os.path.abspath(args.ort_root)
    check_cuda_too = True  # validate CPU and CUDA EP registrations

    files_to_check = op_registration_utils.get_kernel_registration_files(
        ort_root, check_cuda_too)

    for registration_path in files_to_check:
        log.info(f"Processing {registration_path}")

        # a fresh validator per file; the final registrations are checked once the whole file has been read
        validator = RegistrationValidator()
        op_registration_utils.process_kernel_registration_file(registration_path, validator)
        validator.validate_last_registrations()

        if not validator.ok():
            # stop at the first file that fails validation
            sys.exit(-1)