Beispiel #1
0
def build_engine(batch_sizes, workspace_size, sequence_length, config, weights_dict, squad_json, vocab_file, calibrationCacheFile, calib_num):
    """Build a TensorRT engine from the embedding, BERT and SQuAD output layers.

    Args:
        batch_sizes: Forwarded unchanged to ``emb_layernorm``.
        workspace_size: Builder workspace limit in MiB; it is multiplied by
            ``1024 * 1024`` before being assigned to ``max_workspace_size``.
        sequence_length: Forwarded to ``BertCalibrator`` and ``emb_layernorm``.
        config: Object whose ``use_fp16``, ``use_int8``, ``use_qat`` and
            ``use_strict`` attributes select the builder flags. It is also
            passed on to the network-construction helpers.
        weights_dict: Forwarded to the network-construction helpers.
        squad_json: Forwarded to ``BertCalibrator``.
        vocab_file: Forwarded to ``BertCalibrator``.
        calibrationCacheFile: Forwarded to ``BertCalibrator``.
        calib_num: Forwarded to ``BertCalibrator``.

    Returns:
        The value returned by ``builder.build_engine`` for the assembled
        network.
    """
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(explicit_batch_flag) as network, builder.create_builder_config() as builder_config:
        builder_config.max_workspace_size = workspace_size * (1024 * 1024)
        if config.use_fp16:
            builder_config.set_flag(trt.BuilderFlag.FP16)

        # Stays None unless INT8 without QAT is requested, so the cleanup in
        # ``finally`` knows whether there is anything to free.
        calibrator = None
        try:
            if config.use_int8:
                builder_config.set_flag(trt.BuilderFlag.INT8)
                if not config.use_qat:
                    # NOTE(review): the literal 1 is presumably the calibration
                    # batch size -- confirm against BertCalibrator's signature.
                    calibrator = BertCalibrator(squad_json, vocab_file, calibrationCacheFile, 1, sequence_length, calib_num)
                    builder_config.set_quantization_flag(trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
                    builder_config.int8_calibrator = calibrator
            if config.use_strict:
                builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

            # Create the network
            emb_layer = emb_layernorm(builder, network, config, weights_dict, builder_config, sequence_length, batch_sizes)
            embeddings = emb_layer.get_output(0)
            mask_idx = emb_layer.get_output(1)

            bert_out = bert_model(config, weights_dict, network, embeddings, mask_idx)

            squad_logits = squad_output("cls_", config, weights_dict, network, bert_out)
            squad_logits_out = squad_logits.get_output(0)

            network.mark_output(squad_logits_out)

            build_start_time = time.time()
            engine = builder.build_engine(network, builder_config)
            build_time_elapsed = time.time() - build_start_time
            TRT_LOGGER.log(TRT_LOGGER.INFO, "build engine in {:.3f} Sec".format(build_time_elapsed))
            return engine
        finally:
            # Free the calibrator even when network construction or the engine
            # build raises; previously it was only freed on the success path.
            if calibrator is not None:
                calibrator.free()
Beispiel #2
0
def build_engine(batch_sizes, workspace_size, sequence_lengths, config,
                 weights_dict, squad_json, vocab_file, calibrationCacheFile,
                 calib_num):
    """Build a TensorRT engine from the embedding, BERT and SQuAD output layers.

    Args:
        batch_sizes: Forwarded unchanged to ``emb_layernorm``.
        workspace_size: Builder workspace limit in MiB; it is multiplied by
            ``1024 * 1024`` before being assigned to ``max_workspace_size``.
        sequence_lengths: Sequence of sequence lengths. The last element is
            handed to ``BertCalibrator``; in calibration mode only the last
            element is forwarded to ``emb_layernorm``.
        config: Object whose ``use_fp16``, ``use_int8``, ``use_qat``,
            ``use_strict``, ``use_sparsity``, ``timing_cache`` and
            ``is_calib_mode`` attributes steer the build. It is also passed
            on to the network-construction helpers.
        weights_dict: Forwarded to the network-construction helpers.
        squad_json: Forwarded to ``BertCalibrator``.
        vocab_file: Forwarded to ``BertCalibrator``.
        calibrationCacheFile: Forwarded to ``BertCalibrator``.
        calib_num: Forwarded to ``BertCalibrator``.

    Returns:
        The value returned by ``builder.build_engine`` for the assembled
        network.
    """
    explicit_batch_flag = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
            explicit_batch_flag) as network, builder.create_builder_config(
            ) as builder_config:
        builder_config.max_workspace_size = workspace_size * (1024 * 1024)
        if config.use_fp16:
            builder_config.set_flag(trt.BuilderFlag.FP16)

        # Stays None unless INT8 without QAT is requested, so the cleanup in
        # ``finally`` knows whether there is anything to free.
        calibrator = None
        try:
            if config.use_int8:
                builder_config.set_flag(trt.BuilderFlag.INT8)
                if not config.use_qat:
                    # NOTE(review): the literal 1 is presumably the
                    # calibration batch size -- confirm against
                    # BertCalibrator's signature.
                    calibrator = BertCalibrator(squad_json, vocab_file,
                                                calibrationCacheFile, 1,
                                                sequence_lengths[-1],
                                                calib_num)
                    builder_config.set_quantization_flag(
                        trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
                    builder_config.int8_calibrator = calibrator
            if config.use_strict:
                builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

            if config.use_sparsity:
                TRT_LOGGER.log(TRT_LOGGER.INFO,
                               "Setting sparsity flag on builder_config.")
                builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

            # speed up the engine build for trt major version >= 8
            # 1. disable cudnn tactic
            # 2. load global timing cache
            if trt_version[0] >= 8:
                tactic_source = 1 << int(trt.TacticSource.CUBLAS) | 1 << int(
                    trt.TacticSource.CUBLAS_LT)
                builder_config.set_tactic_sources(tactic_source)
                if config.timing_cache is not None:
                    # Seed the cache from disk when a previous build left one
                    # behind; otherwise start from an empty cache.
                    if os.path.exists(config.timing_cache):
                        with open(config.timing_cache, "rb") as f:
                            cache_bytes = f.read()
                    else:
                        cache_bytes = b""
                    cache = builder_config.create_timing_cache(cache_bytes)
                    builder_config.set_timing_cache(cache,
                                                    ignore_mismatch=False)

            # only use the largest sequence when in calibration mode
            # (assumes sequence_lengths is sorted ascending -- TODO confirm)
            if config.is_calib_mode:
                sequence_lengths = sequence_lengths[-1:]

            # Create the network
            emb_layer = emb_layernorm(builder, network, config, weights_dict,
                                      builder_config, sequence_lengths,
                                      batch_sizes)
            embeddings = emb_layer.get_output(0)
            mask_idx = emb_layer.get_output(1)

            bert_out = bert_model(config, weights_dict, network, embeddings,
                                  mask_idx)

            squad_logits = squad_output("cls_", config, weights_dict, network,
                                        bert_out)
            squad_logits_out = squad_logits.get_output(0)

            network.mark_output(squad_logits_out)

            build_start_time = time.time()
            engine = builder.build_engine(network, builder_config)
            build_time_elapsed = time.time() - build_start_time
            TRT_LOGGER.log(
                TRT_LOGGER.INFO,
                "build engine in {:.3f} Sec".format(build_time_elapsed))

            # save global timing cache
            if trt_version[0] >= 8 and config.timing_cache is not None:
                cache = builder_config.get_timing_cache()
                with cache.serialize() as buffer:
                    with open(config.timing_cache, "wb") as f:
                        f.write(buffer)
                        # Push the data to disk before the file is closed.
                        f.flush()
                        os.fsync(f)

            return engine
        finally:
            # Free the calibrator even when network construction, the engine
            # build or the timing-cache write raises; previously it was only
            # freed on the success path.
            if calibrator is not None:
                calibrator.free()