def build_engine(batch_sizes, workspace_size, sequence_length, config, weights_dict,
                 squad_json, vocab_file, calibrationCacheFile, calib_num):
    """Assemble the network and build a TensorRT engine from it.

    Args:
        batch_sizes: Forwarded unchanged to ``emb_layernorm``.
        workspace_size: Builder workspace limit; multiplied by 1024 * 1024
            before being assigned to ``builder_config.max_workspace_size``
            (i.e. the caller passes MiB).
        sequence_length: Forwarded to ``BertCalibrator`` and ``emb_layernorm``.
        config: Object whose ``use_fp16``, ``use_int8``, ``use_qat`` and
            ``use_strict`` attributes select the builder flags.
        weights_dict: Forwarded to ``emb_layernorm``, ``bert_model`` and
            ``squad_output``.
        squad_json: Forwarded to ``BertCalibrator``.
        vocab_file: Forwarded to ``BertCalibrator``.
        calibrationCacheFile: Forwarded to ``BertCalibrator``.
        calib_num: Forwarded to ``BertCalibrator``.

    Returns:
        Whatever ``builder.build_engine(network, builder_config)`` returns.
    """
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch_flag) as network, \
            builder.create_builder_config() as builder_config:
        builder_config.max_workspace_size = workspace_size * (1024 * 1024)
        if config.use_fp16:
            builder_config.set_flag(trt.BuilderFlag.FP16)

        # Stays None unless INT8 calibration (INT8 without QAT) is requested.
        calibrator = None
        try:
            if config.use_int8:
                builder_config.set_flag(trt.BuilderFlag.INT8)
                if not config.use_qat:
                    # NOTE(review): the literal 1 is presumably the calibration
                    # batch size -- confirm against BertCalibrator's signature.
                    calibrator = BertCalibrator(squad_json, vocab_file, calibrationCacheFile,
                                                1, sequence_length, calib_num)
                    builder_config.set_quantization_flag(
                        trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
                    builder_config.int8_calibrator = calibrator
            if config.use_strict:
                builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

            # Create the network
            emb_layer = emb_layernorm(builder, network, config, weights_dict,
                                      builder_config, sequence_length, batch_sizes)
            embeddings = emb_layer.get_output(0)
            mask_idx = emb_layer.get_output(1)

            bert_out = bert_model(config, weights_dict, network, embeddings, mask_idx)

            squad_logits = squad_output("cls_", config, weights_dict, network, bert_out)
            squad_logits_out = squad_logits.get_output(0)

            network.mark_output(squad_logits_out)

            build_start_time = time.time()
            engine = builder.build_engine(network, builder_config)
            build_time_elapsed = time.time() - build_start_time
            TRT_LOGGER.log(TRT_LOGGER.INFO,
                           "build engine in {:.3f} Sec".format(build_time_elapsed))
            return engine
        finally:
            # Release the calibrator even when network construction or the
            # engine build raises, not only on the success path.
            if calibrator is not None:
                calibrator.free()
def build_engine(batch_sizes, workspace_size, sequence_lengths, config, weights_dict,
                 squad_json, vocab_file, calibrationCacheFile, calib_num):
    """Assemble the network and build a TensorRT engine from it.

    Args:
        batch_sizes: Forwarded unchanged to ``emb_layernorm``.
        workspace_size: Builder workspace limit; multiplied by 1024 * 1024
            before being assigned to ``builder_config.max_workspace_size``
            (i.e. the caller passes MiB).
        sequence_lengths: Sequence of lengths. The last element is used for
            calibration, and in calibration mode only the last element is
            kept (the code treats the last entry as the largest).
        config: Object whose ``use_fp16``, ``use_int8``, ``use_qat``,
            ``use_strict``, ``use_sparsity``, ``timing_cache`` and
            ``is_calib_mode`` attributes drive the build.
        weights_dict: Forwarded to ``emb_layernorm``, ``bert_model`` and
            ``squad_output``.
        squad_json: Forwarded to ``BertCalibrator``.
        vocab_file: Forwarded to ``BertCalibrator``.
        calibrationCacheFile: Forwarded to ``BertCalibrator``.
        calib_num: Forwarded to ``BertCalibrator``.

    Returns:
        Whatever ``builder.build_engine(network, builder_config)`` returns.
    """
    explicit_batch_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch_flag) as network, \
            builder.create_builder_config() as builder_config:
        builder_config.max_workspace_size = workspace_size * (1024 * 1024)
        if config.use_fp16:
            builder_config.set_flag(trt.BuilderFlag.FP16)

        # Stays None unless INT8 calibration (INT8 without QAT) is requested.
        calibrator = None
        try:
            if config.use_int8:
                builder_config.set_flag(trt.BuilderFlag.INT8)
                if not config.use_qat:
                    # NOTE(review): the literal 1 is presumably the calibration
                    # batch size -- confirm against BertCalibrator's signature.
                    calibrator = BertCalibrator(squad_json, vocab_file, calibrationCacheFile,
                                                1, sequence_lengths[-1], calib_num)
                    builder_config.set_quantization_flag(
                        trt.QuantizationFlag.CALIBRATE_BEFORE_FUSION)
                    builder_config.int8_calibrator = calibrator
            if config.use_strict:
                builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

            if config.use_sparsity:
                TRT_LOGGER.log(TRT_LOGGER.INFO, "Setting sparsity flag on builder_config.")
                builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

            # Speed up the engine build for TensorRT major version >= 8:
            #   1. disable the cuDNN tactic source (only cuBLAS / cuBLASLt stay on)
            #   2. load the global timing cache
            if trt_version[0] >= 8:
                tactic_source = (1 << int(trt.TacticSource.CUBLAS)
                                 | 1 << int(trt.TacticSource.CUBLAS_LT))
                builder_config.set_tactic_sources(tactic_source)
                if config.timing_cache is not None:
                    # Seed from the file when it exists, otherwise start from
                    # an empty cache that gets populated during this build.
                    if os.path.exists(config.timing_cache):
                        with open(config.timing_cache, "rb") as f:
                            cache_blob = f.read()
                    else:
                        cache_blob = b""
                    cache = builder_config.create_timing_cache(cache_blob)
                    builder_config.set_timing_cache(cache, ignore_mismatch=False)

            # only use the largest sequence when in calibration mode
            if config.is_calib_mode:
                sequence_lengths = sequence_lengths[-1:]

            # Create the network
            emb_layer = emb_layernorm(builder, network, config, weights_dict,
                                      builder_config, sequence_lengths, batch_sizes)
            embeddings = emb_layer.get_output(0)
            mask_idx = emb_layer.get_output(1)

            bert_out = bert_model(config, weights_dict, network, embeddings, mask_idx)

            squad_logits = squad_output("cls_", config, weights_dict, network, bert_out)
            squad_logits_out = squad_logits.get_output(0)

            network.mark_output(squad_logits_out)

            build_start_time = time.time()
            engine = builder.build_engine(network, builder_config)
            build_time_elapsed = time.time() - build_start_time
            TRT_LOGGER.log(TRT_LOGGER.INFO,
                           "build engine in {:.3f} Sec".format(build_time_elapsed))

            # save global timing cache
            if trt_version[0] >= 8 and config.timing_cache is not None:
                cache = builder_config.get_timing_cache()
                with cache.serialize() as buffer:
                    with open(config.timing_cache, "wb") as f:
                        f.write(buffer)
                        # Push the data through Python's and the OS's buffers
                        # before the file is closed.
                        f.flush()
                        os.fsync(f)

            return engine
        finally:
            # Release the calibrator even when network construction or the
            # engine build raises, not only on the success path.
            if calibrator is not None:
                calibrator.free()