def PrintOperations(biasOp, webnnOpType, webnnParamsStr, fusedReluMappedInfo,
                    outputOp, test):
    """Write the JavaScript statement(s) that build the WebNN operation.

    Depending on the arguments, up to three chained builder calls are emitted:
    the mapped operation itself, an ``add`` of the bias, and a fused
    activation (via PrintMappedReluOpertions).

    Args:
        biasOp: operand to add after the main operation, or None.
        webnnOpType: name of the builder method to call.
        webnnParamsStr: already formatted argument list for that method.
        fusedReluMappedInfo: None, or a 2-item sequence ``(flag, info)``.
            ``info`` is the dict handed to PrintMappedReluOpertions.  When
            there is no bias, a truthy ``flag`` means the main operation is
            emitted first and the activation is applied to its result; a
            falsy ``flag`` means only the activation call is emitted, taking
            ``webnnParamsStr`` as its operand.
        outputOp: name given to the final JavaScript constant.
        test: file object the generated code is written to.
    """
    mainCall = "builder.%s(%s);" % (webnnOpType, webnnParamsStr)

    if biasOp is not None:
        IndentedPrint("const interOut0 = " + mainCall, indent=4, file=test)
        if fusedReluMappedInfo is None:
            # Add 'add' operation
            IndentedPrint("const %s = builder.add(interOut0, %s);" % \
                          (outputOp, biasOp), indent=4, file=test)
            return
        # Add 'add' operation
        IndentedPrint("const interOut1 = builder.add(interOut0, %s);" % \
                      biasOp, indent=4, file=test)
        # Add 'relu' or 'clamp' operation
        PrintMappedReluOpertions(fusedReluMappedInfo[1], outputOp,
                                 'interOut1', test)
        return

    if fusedReluMappedInfo is None:
        IndentedPrint("const %s = " % outputOp + mainCall,
                      indent=4, file=test)
        return

    if fusedReluMappedInfo[0]:
        IndentedPrint("const interOut0 = " + mainCall, indent=4, file=test)
        # Add 'relu' or 'clamp' operation
        PrintMappedReluOpertions(fusedReluMappedInfo[1], outputOp,
                                 'interOut0', test)
    else:
        PrintMappedReluOpertions(fusedReluMappedInfo[1], outputOp,
                                 webnnParamsStr, test)


def PrintMappedReluOpertions(fusedReluMappedInfo, outputOp, operandName,
                             test=None):
    """Write the builder call for a mapped activation operation.

    Args:
        fusedReluMappedInfo: dict with a ``'name'`` entry (builder method
            name) and an optional ``'options'`` entry that, when present, is
            appended as a second call argument.
        outputOp: name given to the resulting JavaScript constant.
        operandName: text used as the first call argument.
        test: file object to write to.  When omitted, the module-level name
            ``test`` is used if it exists, which is what three-argument
            callers relied on before this parameter was available.
    """
    if test is None:
        # Backward compatibility for callers that do not pass the file.
        test = globals().get('test')

    mappedWebNNOpName = fusedReluMappedInfo['name']
    options = fusedReluMappedInfo.get('options', None)
    if options is None:
        statement = "const %s = builder.%s(%s);" % \
                    (outputOp, mappedWebNNOpName, operandName)
    else:
        statement = "const %s = builder.%s(%s, %s);" % \
                    (outputOp, mappedWebNNOpName, operandName, options)
    IndentedPrint(statement, indent=4, file=test)
def DumpAllInOneCtsTest(test, cts):
    """Append the bodies of all generated test files to one combined file.

    Args:
        test: directory containing one sub-directory per version, each
            holding generated test files.
        cts: path of the combined file, opened in append mode.

    From every generated file the first six lines and the last two lines are
    left out; everything in between is copied verbatim.  Versions and files
    are visited in sorted order.
    """
    versionNames = os.listdir(test)
    with SmartOpen(cts, mode="a") as aioTest:
        InitializeCtsTestFile(aioTest, 3)
        for versionName in sorted(versionNames):
            versionDir = os.path.join(test, versionName)
            for fileName in sorted(os.listdir(versionDir)):
                sourcePath = os.path.join(versionDir, fileName)
                with SmartOpen(sourcePath, mode="r") as readFile:
                    allLines = readFile.readlines()
                # Keep only the lines between the 6-line head and 2-line tail.
                for bodyLine in allLines[6:len(allLines) - 2]:
                    aioTest.write(bodyLine)
        IndentedPrint("});", file=aioTest)
        IndentedPrint("/* eslint-disable max-len */", file=aioTest)
def PrintInputData(oprand, operation, opInsList, name, value, layout, test):
    """Write the ``const <operand>Data = new <TypedArray>(<value>);`` line.

    The typed-array constructor name comes from
    ``oprand.type.mappingTypedArrayType`` and the value text from
    GetOperandValue; the statement is written to ``test`` with indent 4.
    """
    arrayTypeName = oprand.type.mappingTypedArrayType
    valueText = GetOperandValue(oprand, operation, opInsList, name, layout,
                                value)
    declaration = 'const %sData = new %s(%s);' % \
                  (oprand, arrayTypeName, valueText)
    IndentedPrint(declaration, indent=4, file=test)
def PrintConstant(oprand, operation, opInsList, opInsInfoList, name, layout,
                  test):
    """Write the ``builder.constant(...)`` declaration for an operand.

    The operand descriptor comes from GetWebNNOperandDesc, the data from
    GetOperandValue, and the typed-array constructor name from
    ``oprand.type.mappingTypedArrayType``.  The statement is written to
    ``test`` with indent 4.
    """
    descriptor = GetWebNNOperandDesc(oprand, operation, opInsList,
                                     opInsInfoList, layout)
    valueText = GetOperandValue(oprand, operation, opInsList, name, layout)
    template = "const %s = builder.constant(%s, new %s(%s));"
    fields = (oprand, descriptor, oprand.type.mappingTypedArrayType,
              valueText)
    IndentedPrint(template % fields, indent=4, file=test)
def DumpCtsTest(example, test):
    """Write one JavaScript ``it(...)`` test case per feed dict of ``example``.

    Nothing is written and the function returns early when:
      * the model has more than one operation,
      * the operation has no entry in ``md.MappingDict``,
      * the model uses an 'int8', 'uint8' or 'float16' operand type,
      * CheckOperationWithImplicitPadding, SupportedConvertDepthwiseConv2D or
        SupportedConvertSoftmax rejects the operation (these paths also call
        ClearMappingWebNNOpConfiguration).

    Side effects: appends operation names to ``Configuration.mappingWebNNOp``
    while ``Configuration.successedCounter`` is 0, and increments
    ``Configuration.successedCounter`` once per emitted test case.

    Args:
        example: object exposing ``model``, ``feedDicts`` and ``testName``.
        test: file object the generated test code is written to.
    """
    model = example.model
    if len(model.operations) > 1:
        msg = 'Not convert complicated tests with multi-operations'
        # print(msg, file=sys.stderr)
        return

    nnapiOp = model.operations[0].optype
    if nnapiOp not in md.MappingDict.keys():
        msg = msgTemplate % (nnapiOp, 'none mapped WebNN Opeartion')
        # print(msg, file=sys.stderr)
        return

    # WebNN polyfill API cur supports 'int32' and 'float32'
    unSupportedTypesList = ['int8', 'uint8', 'float16']
    operandTypeList = model.GetMappedOperandTypes()
    usedUnsupportedType = list(
        set(unSupportedTypesList) & set(operandTypeList))
    if len(usedUnsupportedType) > 0:
        msg = msgTemplate % \
              (nnapiOp, 'unsupported %s Operand Type' % usedUnsupportedType)
        # print(msg, file=sys.stderr)
        return

    mappingOpDict = md.MappingDict[nnapiOp]
    mappedWebNNOp = mappingOpDict['webnnOperation']
    if Configuration.successedCounter == 0:
        # Update mappingWebNNOp by first time
        Configuration.mappingWebNNOp.append(mappedWebNNOp)

    # Deep copy: the list is extended below and must not alter MappingDict.
    nnapiOpInsList = copy.deepcopy(mappingOpDict['insList'])
    nnapiOpOptionalInsList = mappingOpDict.get('optionalInsList', [])
    curOpInsList = model.operations[0].ins
    if CheckOperationWithImplicitPadding(nnapiOp, curOpInsList,
                                         nnapiOpInsList,
                                         len(nnapiOpOptionalInsList)):
        ClearMappingWebNNOpConfiguration()
        return

    if GetOperandIndex(nnapiOpInsList, 'bias') != -1:
        if Configuration.successedCounter == 0:
            Configuration.mappingWebNNOp.append('add')

    curInputsList = example.model.GetInputs()
    curOutputsList = example.model.GetOutputs()
    curParamsList = example.model.GetParameters()

    fusedReluMappedInfo = None
    actIndex = GetOperandIndex(nnapiOpInsList, 'activation')
    actStatus, actValue = GetParamOperandValue(curParamsList, curOpInsList,
                                               actIndex)
    if actStatus:
        UpdateMappingWebNNOpList(actValue[0])
        fusedReluMappedInfo = (True, GetReluMappedInfo(actValue[0]))
    if nnapiOp == 'RELU1':
        fusedReluMappedInfo = (False, GetReluMappedInfo(2))
    if nnapiOp == 'RELU6':
        fusedReluMappedInfo = (False, GetReluMappedInfo(3))

    nnapiOpInsList.extend(nnapiOpOptionalInsList)
    layoutIndex = GetOperandIndex(nnapiOpInsList, 'layout')
    layoutStatus, layoutValue = GetParamOperandValue(curParamsList,
                                                     curOpInsList,
                                                     layoutIndex)
    # True: 'nchw', False: 'nhwc'
    layout = False if not layoutStatus else layoutValue[0]

    if nnapiOp == 'DEPTHWISE_CONV_2D':
        if not SupportedConvertDepthwiseConv2D(curInputsList[0],
                                               curOutputsList[0], layout):
            ClearMappingWebNNOpConfiguration()
            return

    biasOp = None
    # Test titles get a "/<index>" suffix only when there are several feeds.
    testIndex = 1 if len(example.feedDicts) > 1 else 0

    for inputFeedDict, outputFeedDict in example.feedDicts:
        if nnapiOp == 'SOFTMAX':
            if not SupportedConvertSoftmax(nnapiOpInsList, curInputsList[0],
                                           curParamsList, curOpInsList):
                ClearMappingWebNNOpConfiguration()
                return

        IndentedPrint("", file=test) # Add blank line
        testPurpose = 'test %s converted from %s test' % \
                      (' + '.join(Configuration.mappingWebNNOp),
                       str(example.testName))
        if testIndex > 0:
            testPurpose = "%s/%d" % (testPurpose, testIndex)
        IndentedPrint("it('%s', async function() {" % testPurpose,
                      indent=2, file=test)
        IndentedPrint("// Converted test case (from: %s/%s)" % \
                      (tg.FileNames.version,
                       os.path.basename(tg.FileNames.specFile)),
                      indent=4, file=test)
        IndentedPrint("const builder = new MLGraphBuilder(context);",
                      indent=4, file=test)

        computeParamsList = []

        # Create operand(s) by ModelBuilder.input
        for op in curInputsList:
            opInsDict = nnapiOpInsList[curOpInsList.index(op)]
            mappingParamIndex = opInsDict['mappingParamIndex']
            if mappingParamIndex != -1:
                rule = md.MappingRule(opInsDict['mappingRuleType'])
                if rule == md.MappingRule.OPERAND_OPERAND:
                    PrintInputOperand(op, nnapiOp, curOpInsList,
                                      nnapiOpInsList, layout, test)
                    PrintInputData(op, nnapiOp, curOpInsList,
                                   opInsDict['name'], inputFeedDict[op],
                                   layout, test)
                    computeParamsList.append("'%s': {data: %sData}" % \
                                             (op, op))
                elif rule == md.MappingRule.OPERAND_VARIABLE:
                    varValue = inputFeedDict[op]
                    if len(varValue) != 0 and varValue[0] is not None:
                        IndentedPrint('const %s = %s;' % (op, varValue),
                                      indent=4, file=test)
                elif rule == md.MappingRule.OPERAND_ARRAY:
                    varValue = inputFeedDict[op]
                    if len(varValue) != 0:
                        IndentedPrint('const %s = %s;' % (op, varValue),
                                      indent=4, file=test)
            else:
                # Unmapped inputs are only emitted when they are the bias,
                # which PrintOperations later applies through 'add'.
                if opInsDict['name'] == 'bias':
                    biasOp = op
                    PrintInputOperand(op, nnapiOp, curOpInsList,
                                      nnapiOpInsList, layout, test)
                    PrintInputData(op, nnapiOp, curOpInsList,
                                   opInsDict['name'], inputFeedDict[op],
                                   layout, test)
                    computeParamsList.append("'%s': {data: %sData}" % \
                                             (op, op))

        # Create operand(s) by ModelBuilder.constant, or define variable(s)
        for op in curParamsList:
            opInsDict = nnapiOpInsList[curOpInsList.index(op)]
            mappingParamIndex = opInsDict['mappingParamIndex']
            if mappingParamIndex != -1:
                rule = md.MappingRule(opInsDict['mappingRuleType'])
                if rule == md.MappingRule.OPERAND_OPERAND:
                    PrintConstant(op, nnapiOp, curOpInsList, nnapiOpInsList,
                                  opInsDict['name'], layout, test)
                elif rule == md.MappingRule.VARIABLE_VARIABLE:
                    varValue = curParamsList[curParamsList.index(op)].value[0]
                    if opInsDict['name'] == 'layout':
                        if varValue:
                            varValue = "'nchw'"
                        else:
                            varValue = "'nhwc'"
                    IndentedPrint('const %s = %s;' % (op, varValue),
                                  indent=4, file=test)
                elif rule == md.MappingRule.OPERAND_VARIABLE:
                    varValue = curParamsList[curParamsList.index(op)].value
                    if len(varValue) != 0 and varValue[0] is not None:
                        IndentedPrint('const %s = %s;' % (op, varValue),
                                      indent=4, file=test)
                elif rule == md.MappingRule.OPERAND_ARRAY:
                    varValue = curParamsList[curParamsList.index(op)].value
                    if len(varValue) != 0:
                        IndentedPrint('const %s = %s;' % (op, varValue),
                                      indent=4, file=test)
            else:
                if opInsDict['name'] == 'bias':
                    biasOp = op
                    PrintConstant(op, nnapiOp, curOpInsList, nnapiOpInsList,
                                  opInsDict['name'], layout, test)

        # outputOp is a single operand for one output, a list for several.
        if len(curOutputsList) == 1:
            outputOp = curOutputsList[0]
            IndentedPrint("const expected = %s;" % outputFeedDict[outputOp],
                          indent=4, file=test)
        elif len(curOutputsList) > 1:
            outputOp = curOutputsList
            expectedValueList = [outputFeedDict[k] for k in outputOp]
            IndentedPrint("const expected = %s;" % expectedValueList,
                          indent=4, file=test)

        # Update optional parameter value
        optionsKeyValueList = []
        hasLayoutOption = any(optionalIns['name'] == 'layout'
                              for optionalIns in nnapiOpOptionalInsList)
        if hasLayoutOption:
            if not layout:
                # Default 'nchw' layout with WebNN API
                optionsKeyValueList.append(('layout', False))
        if nnapiOp == 'DEPTHWISE_CONV_2D':
            # True: 'nchw' False: 'nhwc'
            chanelIndex = 1 if layout else 3
            groups = outputOp.type.dimensions[chanelIndex]
            optionsKeyValueList.append(('groups', groups))

        mappingParams = GetWebNNOperationParamsList(nnapiOpInsList,
                                                    curOpInsList,
                                                    inputFeedDict,
                                                    curParamsList,
                                                    nnapiOp)
        UpdateWebNNOperationOptionalParamValue(nnapiOp, mappingParams[-1][1],
                                               optionsKeyValueList, layout)
        webnnParamsStr = GetWebNNParamsString(mappingParams)

        if nnapiOp == 'SQRT':
            # An extra constant 0.5 operand is passed as a second argument.
            exponent = "const exponent = builder.constant({type: 'float32'," + \
                       " dimensions: [1]}, new Float32Array([0.5]));"
            IndentedPrint(exponent, indent=4, file=test)
            webnnParamsStr = ', '.join([webnnParamsStr, 'exponent'])

        if nnapiOp in ['CONV_2D', 'DEPTHWISE_CONV_2D']:
            webnnParamsStr = webnnParamsStr.replace("'layout'",
                                                    "'inputLayout'")

        PrintOperations(biasOp, mappedWebNNOp, webnnParamsStr,
                        fusedReluMappedInfo, outputOp, test)

        # Build the graph from the named output operand(s).
        if len(curOutputsList) == 1:
            IndentedPrint("const graph = await builder.build({%s});" % \
                          outputOp, indent=4, file=test)
        elif len(curOutputsList) > 1:
            outputOpNameList = [str(item) for item in outputOp]
            IndentedPrint("const graph = await builder.build({%s});" % \
                          ', '.join(outputOpNameList), indent=4, file=test)
        IndentedPrint("const outputs = await graph.compute({%s});" % \
                      ', '.join(computeParamsList), indent=4, file=test)

        # Check compute output
        criteria = 'utils.ctsFp32RestrictAccuracyCriteria'
        if model.isRelaxed:
            criteria = 'utils.ctsFp32RelaxedAccuracyCriteria'
        if len(curOutputsList) == 1:
            IndentedPrint(
                "utils.checkValue(outputs.%s.data, expected, %s);" % \
                (outputOp, criteria), indent=4, file=test)
        elif len(curOutputsList) > 1:
            IndentedPrint('for (let i = 0; i < %d; i++) {' % \
                          len(curOutputsList), indent=4, file=test)
            # The Python list repr of the names doubles as a JS array literal.
            dataStr = 'outputs[%s[i]].data' % ['%s' % k for k in outputOp]
            IndentedPrint(
                "utils.checkValue(%s, expected[i], %s);" % \
                (dataStr, criteria), indent=6, file=test)
            IndentedPrint("}", indent=4, file=test)

        IndentedPrint("});", indent=2, file=test)
        testIndex += 1
        Configuration.successedCounter += 1
def PrintInputOperand(oprand, operation, opInsList, opInsInfoList, layout,
                      test):
    """Write the ``builder.input('<name>', <descriptor>)`` declaration.

    The descriptor text comes from GetWebNNOperandDesc; the operand's string
    form is used both as the JavaScript constant name and as the input name.
    The statement is written to ``test`` with indent 4.
    """
    descriptor = GetWebNNOperandDesc(oprand, operation, opInsList,
                                     opInsInfoList, layout)
    template = "const %s = builder.input('%s', %s);"
    IndentedPrint(template % (oprand, oprand, descriptor),
                  indent=4, file=test)
if __name__ == '__main__':
    # Script entry point: for every spec file selected on the command line,
    # run the spec, dump its examples as a CTS test file, then either delete
    # the file (nothing was converted) or rename it after the mapped WebNN
    # operation(s).
    ParseCmdLine()
    while tg.FileNames.NextFile():
        # Per-spec state consulted and updated by DumpCtsTest.
        Configuration.mappingWebNNOp = []
        Configuration.successedCounter = 0
        # print("Generating test(s) from spec: %s" % tg.FileNames.specFile,
        #       file=sys.stderr)

        # Read the spec through a context manager so the handle is closed
        # before the (possibly long) generation step.
        with open(tg.FileNames.specFile, "r") as specSource:
            specCode = specSource.read()
        # NOTE(review): exec() runs the spec file as arbitrary Python in this
        # module's namespace; only feed it spec files from a trusted source.
        # It is kept at module level on purpose so the executed code sees the
        # same globals as before.
        exec(specCode)

        testFile = tg.FileNames.testFile
        # Keep this binding named `test`: helper code in this file may
        # resolve `test` as a module-level global.
        with SmartOpen(testFile) as test:
            InitializeCtsTestFile(test, 4)
            Example.DumpAllExamples(DumpTest=DumpCtsTest, test=test)
            IndentedPrint("});", file=test)
            IndentedPrint("/* eslint-disable max-len */", file=test)

        if Configuration.successedCounter == 0:
            # No test case was converted from this spec; drop the empty file.
            os.remove(testFile)
        else:
            newName = 'test_%s_converted_from_%s' % \
                      ('_'.join(Configuration.mappingWebNNOp),
                       os.path.basename(testFile))
            renamedFile = os.path.join(os.path.dirname(testFile), newName)
            os.rename(testFile, renamedFile)
            # print("Successfully generated CTS test: %s" % renamedFile,
            #       file=sys.stderr)