def makeModel(
    baseModelBuilder,
    containStaticOutput=False,
    listQuant=None,  # [bool, ...]
    listCntQuantPhase=None,  # [int, ...]
    listGate=None,  # [[float, ...], ...]
    listfilterGate=None,
    listQuant_input_minmax=None,  # [[int, int, ...], ...]
    listQuant_filter_minmax=None,  # [[int, int], ...]
    listNum_bits=None,  # [[int, ...], ...]
):
    """Configure ``quantControl``, build a model, and optionally extend its outputs.

    The quantization parameter lists are forwarded to
    ``quantControl.setQuantParamList`` and the previously collected conv quant
    info is cleared *before* ``baseModelBuilder`` is called.

    Args:
        baseModelBuilder: Zero-argument callable that returns the model.
        containStaticOutput: When truthy, the two lists returned by
            ``quantControl.getConvQuantInfo()`` are appended to the model's
            outputs and a new ``keras.Model`` is returned.
        listQuant, listCntQuantPhase, listGate, listfilterGate,
        listQuant_input_minmax, listQuant_filter_minmax, listNum_bits:
            Forwarded by keyword to ``quantControl.setQuantParamList``.
            ``None`` (the default) is forwarded as a new empty list.

    Returns:
        The model produced by ``baseModelBuilder``, or, when
        ``containStaticOutput`` is truthy, a ``keras.Model`` with the same
        inputs and the extended output list.
    """

    def _or_empty(value):
        # A fresh list per call: a shared ``[]`` default would be the same
        # object on every call and could be mutated or retained downstream.
        return [] if value is None else value

    # Set parameter
    quantControl.setQuantParamList(
        listQuant=_or_empty(listQuant),
        listCntQuantPhase=_or_empty(listCntQuantPhase),
        listGate=_or_empty(listGate),
        listfilterGate=_or_empty(listfilterGate),
        listQuant_input_minmax=_or_empty(listQuant_input_minmax),
        listQuant_filter_minmax=_or_empty(listQuant_filter_minmax),
        listNum_bits=_or_empty(listNum_bits),
    )
    quantControl.clearConvQuantInfo()

    # Build model
    model = baseModelBuilder()

    # Modify the output
    if containStaticOutput:
        ori_inputs = model.inputs
        listPercentageOutput, listStaticInfoOutput = quantControl.getConvQuantInfo()
        new_outputs = model.outputs + listPercentageOutput + listStaticInfoOutput
        model = keras.Model(inputs=ori_inputs, outputs=new_outputs)

    return model
def addModelStaticOutput(model, quantControl):
    """Return a new model whose outputs also include the conv quant info.

    Args:
        model: Object exposing ``inputs`` and ``outputs``; the outputs are
            concatenated with ``+``, so they are expected to be a list.
        quantControl: Object whose ``getConvQuantInfo()`` returns two lists
            (percentage outputs, static info outputs).

    Returns:
        A ``Model`` built from the original inputs and the original outputs
        followed by the percentage outputs and then the static info outputs.
    """
    ori_inputs = model.inputs
    ori_outputs = model.outputs

    # Emit the statistics information during validation.
    listPercentageOutput, listStaticInfoOutput = quantControl.getConvQuantInfo()
    new_outputs = ori_outputs + listPercentageOutput + listStaticInfoOutput

    # Rebuild the model with the extended output list.
    return Model(inputs=ori_inputs, outputs=new_outputs)