def Cast(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF Cast node to an uninterpreted Cast call carrying src/dst dtype names."""
    inputs = curNode.getInputsRef()
    assert len(inputs) == 1
    attrs = curNode.getAttrMapRef()
    srcType = attrs["SrcT"].getDataType()
    dstType = attrs["DstT"].getDataType()
    # Arguments: input variable followed by the dtype names as identifiers.
    callArgs = [
        AST.ID(dictNodeNameToOutVarStr[inputs[0]]),
        AST.ID(srcType.name),
        AST.ID(dstType.name),
    ]
    castCall = AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.Cast.name,
        callArgs,
    )
    return (None, {curNode.getName(): castCall})
def Const(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF Const node.

    Produces either a scalar literal (rank-0, constant value), a CreateTensor
    call (constant-filled tensor), or an AST.Decl initialised from the
    tensor's value array when the content arrives as a byte array.
    """
    assert len(curNode.getInputsRef()) == 0
    tensor = curNode.getAttrMapRef()["value"].getTensor()
    curNodeDataType = curNode.getAttrMapRef()["dtype"].getDataType()
    # Copy the shape so later mutation cannot affect the graph's tensor.
    curNodeShape = tensor.getShapeRef()[:]

    tensorConstantVal = tensor.getConstantVal()
    if tensorConstantVal is not None:
        # Whole tensor is one repeated constant: wrap it in a literal node.
        if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
            literal = AST.Int(tensorConstantVal, 32, isSecret=False)
        elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
            literal = AST.Float(tensorConstantVal, isSecret=False)
        else:
            assert False
        if len(curNodeShape) == 0:
            # Rank-0 tensor: the literal itself is the result.
            retAST = literal
        else:
            retAST = AST.UninterpFuncCall(
                curNodeShape,
                TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
                [literal],
                isSecret=False,
            )
    else:
        # Content given as byte array: lift each extracted value to a literal.
        contents = tensor.getContentAsValArr()[:]
        if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
            elems = [AST.Int(x, 32, isSecret=False) for x in contents]
        elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
            elems = [AST.Float(x, isSecret=False) for x in contents]
        else:
            assert False
        retAST = AST.Decl(curNodeShape, None, elems, isSecret=False)
    return (None, {curNode.getName(): retAST})
def helper_processPool(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
    typeOfPool: str,
):
    """Shared lowering for TF pooling nodes into an AST.Pool.

    typeOfPool selects the pool kind ("MAXPOOL" or "AVGPOOL"); anything else
    is a hard error. Strides/ksize are laid out NHWC, so the batch and channel
    entries (index 0 and 3) must be 1.

    Fix: removed the dead local `options = {}`, which was never read or
    written after creation.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 1

    stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
    assert (stridesUsed[0] == 1) and (stridesUsed[3] == 1)
    strideH = stridesUsed[1]
    strideW = stridesUsed[2]

    kSizeUsed = curNode.getAttrMapRef()["ksize"].getList().getILi()
    assert (kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)
    kSizeH = kSizeUsed[1]
    kSizeW = kSizeUsed[2]

    inputShape = extraNodeInfoDict[inputsRef[0]][0]
    imgH = inputShape[1]
    imgW = inputShape[2]

    paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
    [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(
        imgH, imgW, kSizeH, kSizeW, strideH, strideW, paddingUsedStr
    )

    poolType = None
    if typeOfPool == "MAXPOOL":
        poolType = AST.Pool.PoolType.MaxPool
    elif typeOfPool == "AVGPOOL":
        poolType = AST.Pool.PoolType.AvgPool
    else:
        print("Unknown type of pooling layer.", file=sys.stderr)
        assert False
    return (
        None,
        {
            curNode.getName(): AST.Pool(
                poolType,
                AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                {
                    AST.PaddingKeysDict.FH: kSizeH,
                    AST.PaddingKeysDict.FW: kSizeW,
                    AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
                    AST.PaddingKeysDict.zPadHRight: zPadHRight,
                    AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
                    AST.PaddingKeysDict.zPadWRight: zPadWRight,
                    AST.PaddingKeysDict.strideH: strideH,
                    AST.PaddingKeysDict.strideW: strideW,
                },
            )
        },
    )
def VariableV2(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF VariableV2 node to a secret AST.Input provided by the SERVER.

    Athos must distinguish model from image in the compiled secure-inference
    code: PlaceHolder nodes are assumed to represent the image (client input)
    while Variable nodes carry model parameters, hence inputByParty=SERVER.
    NOTE: since this becomes an input node, it is also prefixed at the top in
    ProcessTFGraph::prefixAllPlaceHolderNodes().
    """
    attrMap = curNode.getAttrMapRef()
    # Copy the dims so the graph's shape object is never aliased.
    shapeLi = attrMap["shape"].getShape().getDimRef()[:]
    inputType = attrMap["dtype"].getDataType()
    inputAST = AST.Input(
        shapeLi,
        inputType.name,
        isSecret=True,
        inputByParty=AST.Party.SERVER,
    )
    return (None, {curNode.getName(): inputAST})
def Pad(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Pad node to an uninterpreted Pad call.

    Only CONSTANT mode (mode == 0) with constant_values == 0 is supported;
    SYMMETRIC and REFLECT padding are asserted away until needed.
    Note: unlike most lowerings here, the AST is returned bare rather than in
    a {nodeName: ast} dict, matching this (older) API variant's convention.

    Fix: removed the dead local `inputTensorShapeLi`, which was computed from
    extraNodeInfoDict but never used.
    """
    # Mode refers to 'CONSTANT' (0), 'REFLECT' or 'SYMMETRIC'.
    mode = 0
    if "\"mode\"" in curNode.getAttrMapRef():
        mode = curNode.getAttrMapRef()["\"mode\""].getI()

    constant_values = 0
    if "\"constant_values\"" in curNode.getAttrMapRef():
        constant_values = curNode.getAttrMapRef()["\"constant_values\""].getI()

    # For now to make life easy - deal with SYMMETRIC and REFLECT when time comes.
    assert mode == 0 and constant_values == 0

    inputsRef = curNode.getInputsRef()
    return (
        None,
        AST.UninterpFuncCall(
            extraNodeInfoDict[curNode.getName()][0],
            TFNodesAST.UninterpFuncCallNames.Pad.name,
            [
                AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
            ],
            outputDiffInpDims=1,
        ),
    )
def Conv3DBackpropInputV2(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a 3D transposed convolution (gradient of conv3d w.r.t. its input)
    to an AST.BOp with the '#T' operator and ConvDim=3."""
    inputsRef = curNode.getInputsRef()
    # Inputs are (output_shape, filter, input).
    assert len(inputsRef) == 3

    stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
    assert stridesUsed[0] == 1 and stridesUsed[4] == 1
    strideD, strideH, strideW = stridesUsed[1], stridesUsed[2], stridesUsed[3]

    filterShape = extraNodeInfoDict[inputsRef[1]][0]
    FD, FH, FW = filterShape[0], filterShape[1], filterShape[2]

    inputShape = extraNodeInfoDict[inputsRef[2]][0]
    inputD, inputH, inputW = inputShape[1], inputShape[2], inputShape[3]

    outputShape = extraNodeInfoDict[curNode.getName()][0]
    outputD, outputH, outputW = outputShape[1], outputShape[2], outputShape[3]

    paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()

    # Important: Using outputH and outputW below is not an error!
    # For convTranspose, the parameters stored on the node describe the conv of
    # which this convTranspose is the inverse, which is why helper_findPadding
    # is called with the *output* image dims. The zPads computed are therefore
    # those of the conv of which this convTranspose is an inverse.
    [
        zPadDLeft,
        zPadDRight,
        zPadHLeft,
        zPadHRight,
        zPadWLeft,
        zPadWRight,
    ] = TFNodesAST.helper_findPadding(
        outputH,
        outputW,
        FH,
        FW,
        strideH,
        strideW,
        paddingUsedStr,
        imgD=outputD,
        FD=FD,
        strideD=strideD,
    )

    options = {
        AST.PaddingKeysDict.FD: FD,
        AST.PaddingKeysDict.FH: FH,
        AST.PaddingKeysDict.FW: FW,
        AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
        AST.PaddingKeysDict.zPadDRight: zPadDRight,
        AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
        AST.PaddingKeysDict.zPadHRight: zPadHRight,
        AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
        AST.PaddingKeysDict.zPadWRight: zPadWRight,
        AST.PaddingKeysDict.strideD: strideD,
        AST.PaddingKeysDict.strideH: strideH,
        AST.PaddingKeysDict.strideW: strideW,
        AST.PaddingKeysDict.ConvDim: 3,
        AST.PaddingKeysDict.outputImgD: outputD,
        AST.PaddingKeysDict.outputImgH: outputH,
        AST.PaddingKeysDict.outputImgW: outputW,
    }
    return (
        None,
        {
            curNode.getName(): AST.BOp(
                AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
                TFNodesAST.getOperatorsIdx("#T"),
                AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
                options,
            )
        },
    )
def VariableV2(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF VariableV2 node to an AST.Input (older API variant).

    TODO_TAB: for inference, the Decl lowering was commented out and input
    nodes are inserted instead.
    TODO: the dataType passed to the node is currently ignored by SeeDot; fix later.
    NOTE: since this becomes an input node, it is also prefixed at the top in
    ProcessTFGraph::prefixAllPlaceHolderNodes().
    """
    attrMap = curNode.getAttrMapRef()
    # Copy the dims so the graph's shape object is never aliased.
    shapeLi = attrMap["\"shape\""].getShape().getDimRef()[:]
    dtype = attrMap["\"dtype\""].getDataType()
    return (None, AST.Input(shapeLi, dtype.name))
def Conv2D(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF Conv2D node to an AST.BOp with the '#' (conv) operator.

    Strides are laid out NHWC, so batch/channel strides (index 0 and 3)
    must be 1.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 2

    stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
    assert stridesUsed[0] == 1 and stridesUsed[3] == 1
    strideH, strideW = stridesUsed[1], stridesUsed[2]

    inputShape = extraNodeInfoDict[inputsRef[0]][0]
    imgH, imgW = inputShape[1], inputShape[2]

    filterShape = extraNodeInfoDict[inputsRef[1]][0]
    FH, FW = filterShape[0], filterShape[1]

    paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
    [zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(
        imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr
    )

    options = {
        AST.PaddingKeysDict.FH: FH,
        AST.PaddingKeysDict.FW: FW,
        AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
        AST.PaddingKeysDict.zPadHRight: zPadHRight,
        AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
        AST.PaddingKeysDict.zPadWRight: zPadWRight,
        AST.PaddingKeysDict.strideH: strideH,
        AST.PaddingKeysDict.strideW: strideW,
    }
    return (
        None,
        {
            curNode.getName(): AST.BOp(
                AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                TFNodesAST.getOperatorsIdx("#"),
                AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
                options,
            )
        },
    )
def Conv3D(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Conv3D node to an AST.BOp with the '#' operator and ConvDim=3.

    Strides are laid out NDHWC, so batch/channel strides (index 0 and 4)
    must be 1.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 2

    stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
    assert stridesUsed[0] == 1 and stridesUsed[4] == 1
    strideD, strideH, strideW = stridesUsed[1], stridesUsed[2], stridesUsed[3]

    inputShape = extraNodeInfoDict[inputsRef[0]][0]
    imgD, imgH, imgW = inputShape[1], inputShape[2], inputShape[3]

    filterShape = extraNodeInfoDict[inputsRef[1]][0]
    FD, FH, FW = filterShape[0], filterShape[1], filterShape[2]

    paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()
    [
        zPadDLeft,
        zPadDRight,
        zPadHLeft,
        zPadHRight,
        zPadWLeft,
        zPadWRight,
    ] = TFNodesAST.helper_findPadding(
        imgH, imgW, FH, FW, strideH, strideW, paddingUsedStr, imgD, FD, strideD
    )

    options = {
        AST.PaddingKeysDict.FD: FD,
        AST.PaddingKeysDict.FH: FH,
        AST.PaddingKeysDict.FW: FW,
        AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
        AST.PaddingKeysDict.zPadDRight: zPadDRight,
        AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
        AST.PaddingKeysDict.zPadHRight: zPadHRight,
        AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
        AST.PaddingKeysDict.zPadWRight: zPadWRight,
        AST.PaddingKeysDict.strideD: strideD,
        AST.PaddingKeysDict.strideH: strideH,
        AST.PaddingKeysDict.strideW: strideW,
        AST.PaddingKeysDict.ConvDim: 3,
    }
    return (
        None,
        {
            curNode.getName(): AST.BOp(
                AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                TFNodesAST.getOperatorsIdx("#"),
                AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
                options,
            )
        },
    )
def Mean(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Mean (reduce_mean) node to an AST.Reduce.

    inputsRef[0] is the tensor to reduce; inputsRef[1] names a node whose
    "value" tensor holds the reduction axes (scalar or rank >= 1).

    Fix: replaced the direct dunder call `graph.__getitem__(name)` with the
    idiomatic subscript `graph[name]` (same behavior, standard Python idiom).
    """
    inputsRef = curNode.getInputsRef()
    attrMapRef = curNode.getAttrMapRef()
    assert len(inputsRef) == 2
    keepdims = False
    if "keep_dims" in attrMapRef:
        keepdims = attrMapRef["keep_dims"].getB()

    # Resolve the reduction axes from the node feeding input 1.
    reductionAxesNodeName = inputsRef[1]
    redAxesN = graph[reductionAxesNodeName]
    redAxesT = redAxesN.getAttrVal("value").getTensor()
    rank = redAxesT.getShapeRef().getRank()
    if rank != 0:
        reductionAxesList = redAxesT.getContentAsValArr()
    else:
        # Rank-0 axes tensor: a single scalar axis.
        reductionAxesList = [redAxesT.getConstantVal()]

    curNodeShapeLi = extraNodeInfoDict[curNode.getName()][0]
    return (
        None,
        {
            curNode.getName(): AST.Reduce(
                AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                keepdims,
                curNodeShapeLi,
                TFNodesAST.getOperatorsIdx("mean"),
                reductionAxesList,
            )
        },
    )
def Placeholder(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF Placeholder node to a secret AST.Input provided by the CLIENT.

    Athos must distinguish model from image in the compiled secure-inference
    code: PlaceHolder nodes are assumed to represent the image (client input)
    while Variable nodes carry model parameters, hence inputByParty=CLIENT.
    """
    shapeLi = extraNodeInfoDict[curNode.getName()][0]
    dtype = curNode.getAttrMapRef()["dtype"].getDataType()
    assert dtype is not Graph.DataTypeEnum.DT_INVALID
    inputAST = AST.Input(
        shapeLi, dtype.name, isSecret=True, inputByParty=AST.Party.CLIENT
    )
    return (None, {curNode.getName(): inputAST})
def MatMul(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF MatMul node to an AST.BOp('*'), honouring transpose_a/b attrs."""
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 2
    attrMapRef = curNode.getAttrMapRef()

    def operand(idx, transposeAttr):
        # Wrap the operand in Transp when its transpose attribute is set.
        node = AST.ID(dictNodeNameToOutVarStr[inputsRef[idx]])
        if transposeAttr in attrMapRef and attrMapRef[transposeAttr].getB():
            node = AST.Transp(node)
        return node

    lhs = operand(0, "transpose_a")
    rhs = operand(1, "transpose_b")
    return (
        None,
        {curNode.getName(): AST.BOp(lhs, TFNodesAST.getOperatorsIdx("*"), rhs)},
    )
def Squeeze(
    graph: Graph.Graph,
    curNode: Graph.Node,
    dictNodeNameToOutVarStr: dict,
    extraNodeInfoDict: dict,
):
    """Lower a TF Squeeze node to an uninterpreted Squeeze call.

    The call's arguments are the squeeze dims (as 32-bit int literals)
    followed by the input tensor's variable.

    Fix: removed dead locals `inputTensorShape`, `inputTensorRank` and
    `squeezeDimsRank`, none of which were ever read.
    """
    inputsRef = curNode.getInputsRef()
    squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi()
    return (
        None,
        {
            curNode.getName(): AST.UninterpFuncCall(
                extraNodeInfoDict[curNode.getName()][0],
                TFNodesAST.UninterpFuncCallNames.Squeeze.name,
                [AST.Int(x, 32, isSecret=False) for x in squeezeDims]
                + [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])],
            )
        },
    )
def Placeholder(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Placeholder node to an AST.Input (older API variant).

    The shape comes from extraNodeInfoDict rather than the node's own
    '"shape"' attribute.
    TODO: there has to be some way to take range / understand the dimensions
    for SeeDot.
    """
    shapeLi = extraNodeInfoDict[curNode.getName()][0]
    dtype = curNode.getAttrMapRef()["\"dtype\""].getDataType()
    assert dtype is not Graph.DataTypeEnum.DT_INVALID
    return (None, AST.Input(shapeLi, dtype.name))
def ZerosLike(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF ZerosLike node to a CreateTensor call filled with literal 0."""
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 1
    outDType = curNode.getAttrMapRef()["T"].getDataType()
    assert outDType is not Graph.DataTypeEnum.DT_INVALID
    zeroFill = AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
        [AST.Int(0, isSecret=False)],
        isSecret=False,
    )
    return (None, {curNode.getName(): zeroFill})
def Slice(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Slice node to a CreateCopy call (older API variant: bare AST).

    Inputs are (tensor, begin indices, sizes).
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 3
    # Retained to mirror the original lookup; the value itself is unused here.
    curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType()
    copyArgs = [
        AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),  # of this
        AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),  # begin idx
        AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),  # size
    ]
    sliceCall = AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.CreateCopy.name,
        copyArgs,
    )
    return (None, sliceCall)
def Mean(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Mean node to AST.Reduce (older API variant: bare AST return).

    The axes tensor is passed as a second ID operand; keep_dims is encoded as
    an Int(0/1) literal.
    """
    inputsRef = curNode.getInputsRef()
    attrMapRef = curNode.getAttrMapRef()
    assert len(inputsRef) == 2
    keepdims = False
    if "\"keep_dims\"" in attrMapRef:
        keepdims = attrMapRef["\"keep_dims\""].getB()
    outShapeLi = extraNodeInfoDict[curNode.getName()][0]
    reduceAST = AST.Reduce(
        AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
        AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
        AST.Int(int(keepdims), 32, isSecret=False),
        outShapeLi,
        TFNodesAST.getOperatorsIdx("mean"),
    )
    return (None, reduceAST)
def ConcatV2(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF ConcatV2 node to an uninterpreted Concat<N>T call.

    The node carries N tensor inputs plus one extra input for the axis.
    TODO: since the concat axis is constant (hence known here), the inputs'
    sizes along that dim should be passed to the call; hardcoded for now.
    """
    inputsRef = curNode.getInputsRef()
    N = curNode.getAttrMapRef()["N"].getI()
    assert len(inputsRef) == N + 1  # one extra input carries the axis
    concatCall = AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.Concat.name + str(N) + "T",
        [AST.ID(dictNodeNameToOutVarStr[inp]) for inp in inputsRef],
        outputDiffInpDims=1,
    )
    return (None, {curNode.getName(): concatCall})
def Fill(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Fill node to a CreateTensor call filled with input 1's value.

    inputsRef[0] is the target-shape tensor (must be rank 1); inputsRef[1] is
    the scalar fill value.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 2
    shapeOfShapeInput = extraNodeInfoDict[inputsRef[0]][0]
    # inputsRef[0] denotes a shape and should have a rank of 1.
    assert len(shapeOfShapeInput) == 1
    outDType = curNode.getAttrMapRef()["T"].getDataType()
    assert outDType is not Graph.DataTypeEnum.DT_INVALID
    fillCall = AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
        [AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])],
        isSecret=False,
    )
    return (None, {curNode.getName(): fillCall})
def Identity(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF Identity node to a CreateIdentity call.

    In SeeDot, J2=J1 would merely create a new reference for J1, so the
    corresponding code cannot simply be an assignment: a new tensor is
    created first and the old one assigned into it.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 1
    dtype = curNode.getAttrMapRef()["T"].getDataType()
    assert dtype is not Graph.DataTypeEnum.DT_INVALID
    outShape = extraNodeInfoDict[curNode.getName()][0]
    identityCall = AST.UninterpFuncCall(
        outShape,
        TFNodesAST.UninterpFuncCallNames.CreateIdentity.name,
        [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])],
    )
    return (None, {curNode.getName(): identityCall})
def MatMul(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TF MatMul node to AST.BOp('*') (older API variant: bare AST).

    Honors the quoted '"transpose_a"' / '"transpose_b"' attribute keys by
    wrapping the corresponding operand in AST.Transp.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 2
    attrMapRef = curNode.getAttrMapRef()

    def operand(idx, transposeAttr):
        # Wrap the operand in Transp when its transpose attribute is set.
        node = AST.ID(dictNodeNameToOutVarStr[inputsRef[idx]])
        if transposeAttr in attrMapRef and attrMapRef[transposeAttr].getB():
            node = AST.Transp(node)
        return node

    lhs = operand(0, "\"transpose_a\"")
    rhs = operand(1, "\"transpose_b\"")
    return (None, AST.BOp(lhs, TFNodesAST.getOperatorsIdx("*"), rhs))