예제 #1
0
    def helper_processPool(
        graph: Graph.Graph,
        curNode: Graph.Node,
        dictNodeNameToOutVarStr: dict,
        extraNodeInfoDict: dict,
        typeOfPool: str,
    ):
        """Build an ``AST.Pool`` for a max- or average-pooling node.

        The ``strides`` and ``ksize`` attributes must have the form
        ``[1, H, W, 1]`` (first and last entries asserted to be 1). Image
        height/width are taken from dims 1 and 2 of the input's shape, and the
        zero padding for the ``padding`` attribute is computed by
        ``TFNodesAST.helper_findPadding``.

        Args:
            graph: Unused here; part of the common node-handler signature.
            curNode: The pooling node being translated (exactly one input).
            dictNodeNameToOutVarStr: Maps node names to output variable names.
            extraNodeInfoDict: Maps node names to per-node info; element 0 is
                used as the node's shape.
            typeOfPool: ``"MAXPOOL"`` or ``"AVGPOOL"``.

        Returns:
            ``(None, {curNode.getName(): AST.Pool(...)})``.

        Raises:
            AssertionError: on malformed attributes or an unknown pool type.
        """
        inputsRef = curNode.getInputsRef()
        assert len(inputsRef) == 1

        attrMapRef = curNode.getAttrMapRef()

        # Pooling windows/strides never span the first or last dimension.
        stridesUsed = attrMapRef["strides"].getList().getILi()
        assert (stridesUsed[0] == 1) and (stridesUsed[3] == 1)
        strideH = stridesUsed[1]
        strideW = stridesUsed[2]

        kSizeUsed = attrMapRef["ksize"].getList().getILi()
        assert (kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)
        kSizeH = kSizeUsed[1]
        kSizeW = kSizeUsed[2]

        inputShape = extraNodeInfoDict[inputsRef[0]][0]
        imgH = inputShape[1]
        imgW = inputShape[2]

        paddingUsedStr = attrMapRef["padding"].getS()
        [zPadHLeft, zPadHRight, zPadWLeft,
         zPadWRight] = TFNodesAST.helper_findPadding(imgH, imgW, kSizeH,
                                                     kSizeW, strideH, strideW,
                                                     paddingUsedStr)

        poolTypeByName = {
            "MAXPOOL": AST.Pool.PoolType.MaxPool,
            "AVGPOOL": AST.Pool.PoolType.AvgPool,
        }
        poolType = poolTypeByName.get(typeOfPool)
        if poolType is None:
            print("Unknown type of pooling layer.", file=sys.stderr)
            # Raise explicitly (same exception type as the former `assert False`)
            # so an unknown pool type is still rejected under `python -O`.
            raise AssertionError(
                "Unknown type of pooling layer: {}".format(typeOfPool))

        return (
            None,
            {
                curNode.getName():
                AST.Pool(
                    poolType,
                    AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                    {
                        AST.PaddingKeysDict.FH: kSizeH,
                        AST.PaddingKeysDict.FW: kSizeW,
                        AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
                        AST.PaddingKeysDict.zPadHRight: zPadHRight,
                        AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
                        AST.PaddingKeysDict.zPadWRight: zPadWRight,
                        AST.PaddingKeysDict.strideH: strideH,
                        AST.PaddingKeysDict.strideW: strideW,
                    },
                )
            },
        )
예제 #2
0
 def Shape(graph: Graph.Graph, curNode: Graph.Node,
           dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Lower a Shape node to ``AST.Func`` applying the 'shape' operator.

     The node must have exactly one input, which is passed as an identifier.

     Returns ``(None, AST.Func(...))`` (the AST itself, not a name->AST dict).
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 1)
     shapeOp = TFNodesAST.getOperatorsIdx('shape')
     sourceVar = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
     shapeAST = AST.Func(shapeOp, sourceVar)
     return (None, shapeAST)
예제 #3
0
    def MatMul(
        graph: Graph.Graph,
        curNode: Graph.Node,
        dictNodeNameToOutVarStr: dict,
        extraNodeInfoDict: dict,
    ):
        """Lower a MatMul node to an ``AST.BOp`` using the '*' operator.

        If the optional boolean attributes ``transpose_a`` / ``transpose_b``
        are present and true, the corresponding operand is wrapped in
        ``AST.Transp`` before the multiplication.

        Returns:
            ``(None, {curNode.getName(): AST.BOp(lhs, '*', rhs)})``.
        """
        inputsRef = curNode.getInputsRef()
        assert len(inputsRef) == 2
        lhsName = dictNodeNameToOutVarStr[inputsRef[0]]
        rhsName = dictNodeNameToOutVarStr[inputsRef[1]]
        lhs = AST.ID(lhsName)
        rhs = AST.ID(rhsName)

        attrs = curNode.getAttrMapRef()
        transposeLhs = (attrs["transpose_a"].getB()
                        if "transpose_a" in attrs else False)
        transposeRhs = (attrs["transpose_b"].getB()
                        if "transpose_b" in attrs else False)
        if transposeLhs:
            lhs = AST.Transp(lhs)
        if transposeRhs:
            rhs = AST.Transp(rhs)

        nodeName = curNode.getName()
        product = AST.BOp(lhs, TFNodesAST.getOperatorsIdx("*"), rhs)
        return (None, {nodeName: product})
예제 #4
0
 def Reshape(graph: Graph.Graph, curNode: Graph.Node,
             dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Lower a Reshape node to ``AST.Reshape``.

     The target shape is this node's own shape from ``extraNodeInfoDict``; the
     second input is only counted, not read.

     Returns ``(None, AST.Reshape(input, targetShape, None))``.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 2)
     sourceVar = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
     targetShape = extraNodeInfoDict[curNode.getName()][0]
     reshaped = AST.Reshape(sourceVar, targetShape, None)
     return (None, reshaped)
예제 #5
0
 def Cast(
     graph: Graph.Graph,
     curNode: Graph.Node,
     dictNodeNameToOutVarStr: dict,
     extraNodeInfoDict: dict,
 ):
     """Lower a Cast node to an uninterpreted ``Cast`` function call.

     The call arguments are the input variable followed by the names of the
     ``SrcT`` and ``DstT`` data types, each wrapped in ``AST.ID``.

     Returns:
         ``(None, {curNode.getName(): AST.UninterpFuncCall(...)})``.
     """
     inputsRef = curNode.getInputsRef()
     assert len(inputsRef) == 1
     attrs = curNode.getAttrMapRef()
     fromType = attrs["SrcT"].getDataType()
     toType = attrs["DstT"].getDataType()

     nodeName = curNode.getName()
     outShape = extraNodeInfoDict[nodeName][0]
     funcName = TFNodesAST.UninterpFuncCallNames.Cast.name
     callArgs = [
         AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
         AST.ID(fromType.name),
         AST.ID(toType.name),
     ]
     return (None, {nodeName: AST.UninterpFuncCall(outShape, funcName, callArgs)})
예제 #6
0
    def Pad(graph: Graph.Graph, curNode: Graph.Node,
            dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Lower a Pad node to an uninterpreted ``Pad`` function call.

        Only the default case is supported: the ``"mode"`` and
        ``"constant_values"`` attributes, when present, must both be 0.
        The call receives input 0 and input 1 (presumably the paddings
        tensor) as identifiers and is flagged ``outputDiffInpDims=1``.

        Returns:
            ``(None, AST.UninterpFuncCall(...))``.
        """
        attrMapRef = curNode.getAttrMapRef()

        # Mode refers to 'CONSTANT', 'REFLECT' or 'SYMMETRIC'
        mode = 0
        if ("\"mode\"" in attrMapRef):
            mode = attrMapRef["\"mode\""].getI()

        constant_values = 0
        if ("\"constant_values\"" in attrMapRef):
            constant_values = attrMapRef["\"constant_values\""].getI()

        # For now to make life easy - deal with SYMMETRIC AND REFLECT when
        # time comes.
        assert (mode == 0 and constant_values == 0)

        inputsRef = curNode.getInputsRef()
        return (None,
                AST.UninterpFuncCall(
                    extraNodeInfoDict[curNode.getName()][0],
                    TFNodesAST.UninterpFuncCallNames.Pad.name, [
                        AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                        AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
                    ],
                    outputDiffInpDims=1))
예제 #7
0
    def Squeeze(
        graph: Graph.Graph,
        curNode: Graph.Node,
        dictNodeNameToOutVarStr: dict,
        extraNodeInfoDict: dict,
    ):
        """Lower a Squeeze node to an uninterpreted ``Squeeze`` function call.

        The call receives each entry of the ``squeeze_dims`` attribute as a
        public 32-bit ``AST.Int``, followed by the input tensor. Negative
        entries, which count from the last dimension, are converted to their
        non-negative index using the input's rank.

        Returns:
            ``(None, {curNode.getName(): AST.UninterpFuncCall(...)})``.
        """
        inputsRef = curNode.getInputsRef()
        inputTensorRank = len(extraNodeInfoDict[inputsRef[0]][0])

        squeezeDims = curNode.getAttrMapRef()["squeeze_dims"].getList().getILi()
        # TF allows negative squeeze dims; the emitted call should only see
        # plain dimension indices. Non-negative dims are passed unchanged.
        normalizedDims = [
            dim + inputTensorRank if dim < 0 else dim for dim in squeezeDims
        ]

        callArgs = [AST.Int(dim, 32, isSecret=False) for dim in normalizedDims]
        callArgs.append(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
        return (
            None,
            {
                curNode.getName():
                AST.UninterpFuncCall(
                    extraNodeInfoDict[curNode.getName()][0],
                    TFNodesAST.UninterpFuncCallNames.Squeeze.name,
                    callArgs,
                )
            },
        )
예제 #8
0
    def Slice(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Lower a Slice node with constant ``begin``/``size`` to ``AST.Slice``.

        Input 1 (begin) and input 2 (size) must be nodes carrying a ``value``
        attribute. Dimension ``i`` becomes the inclusive subscript range
        ``(begin[i], begin[i] + size[i] - 1)``. A size of -1 (TF: "all
        remaining elements") is resolved to the end of that dimension, using
        the input's shape from ``extraNodeInfoDict``.

        Returns:
            ``(None, {curNode.getName(): AST.Slice(input, subscriptRanges)})``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 3)
        beginNode = graph.__getitem__(inputsRef[1])
        sizeNode = graph.__getitem__(inputsRef[2])
        assert beginNode.getAttrVal(
            "value"
        ) is not None, "begin {} of Slice node {} has to be a constant".format(
            inputsRef[1], curNode.getName())
        assert sizeNode.getAttrVal(
            "value"
        ) is not None, "size {} of Slice node {} has to be a constant".format(
            inputsRef[2], curNode.getName())
        begin = beginNode.getAttrVal("value").getTensor().getContentAsValArr()
        size = sizeNode.getAttrVal("value").getTensor().getContentAsValArr()
        assert begin is not None
        assert size is not None
        assert len(begin) == len(size)

        # Only fetched when some size is -1, so existing graphs do not need
        # the input's shape to be known.
        inputShape = None
        subscriptRanges = []
        for i, (start, length) in enumerate(zip(begin, size)):
            if length == -1:
                if inputShape is None:
                    inputShape = extraNodeInfoDict[inputsRef[0]][0]
                end = inputShape[i] - 1
            else:
                end = start + length - 1
            subscriptRanges.append((start, end))

        return (None, {
            curNode.getName():
            AST.Slice(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                      subscriptRanges)
        })
예제 #9
0
    def Mean(graph: Graph.Graph, curNode: Graph.Node,
             dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Lower a Mean node to ``AST.Reduce`` with the 'mean' operator.

        Reduction axes come from input 1's ``value`` tensor. A rank-0 tensor
        yields a one-element axis list built from its constant value. The
        optional ``keep_dims`` attribute defaults to False.

        Returns:
            ``(None, {curNode.getName(): AST.Reduce(...)})``.
        """
        inputsRef = curNode.getInputsRef()
        attrs = curNode.getAttrMapRef()
        assert (len(inputsRef) == 2)
        keepdims = attrs["keep_dims"].getB() if "keep_dims" in attrs else False

        axesNode = graph.__getitem__(inputsRef[1])
        axesTensor = axesNode.getAttrVal("value").getTensor()
        if axesTensor.getShapeRef().getRank() == 0:
            # Scalar axes tensor: its single constant value is the only axis.
            reductionAxesList = [axesTensor.getConstantVal()]
        else:
            reductionAxesList = axesTensor.getContentAsValArr()

        outShape = extraNodeInfoDict[curNode.getName()][0]
        nodeName = curNode.getName()
        reduced = AST.Reduce(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                             keepdims, outShape,
                             TFNodesAST.getOperatorsIdx('mean'),
                             reductionAxesList)
        return (None, {nodeName: reduced})
예제 #10
0
	def ArgMax(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower an ArgMax node to ``AST.ArgMax``.

		Arguments passed, in order: this node's shape, input 0 and input 1 as
		identifiers, and input 0's shape.

		Returns (None, {curNode.getName(): AST.ArgMax(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		srcName = inputsRef[0]
		secondName = inputsRef[1]
		nodeName = curNode.getName()
		argMaxAST = AST.ArgMax(extraNodeInfoDict[nodeName][0],
							AST.ID(dictNodeNameToOutVarStr[srcName]),
							AST.ID(dictNodeNameToOutVarStr[secondName]),
							extraNodeInfoDict[srcName][0])
		return (None, {nodeName: argMaxAST})
예제 #11
0
	def Split(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Split node into one ``AST.Slice`` per output.

		Input 0 is the split dimension and must be a constant node; input 1 is
		the tensor being split. The tensor is cut into ``num_split`` equal
		pieces along that dimension, and the dimension's size must divide
		evenly. A negative split dimension counts from the last axis.

		Output ``i`` is keyed ``curNode.getName()`` for i == 0 and
		``curNode.getName() + ":" + str(i)`` otherwise. Its inclusive subscript
		ranges cover piece ``i`` along the split axis and the full extent of
		every other axis.

		Returns (None, {outputName: AST.Slice(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		# split_dim input. Has to be a constant. We don't support dynamic codegen yet
		axisNode = graph.__getitem__(inputsRef[0])
		axis = axisNode.getAttrVal("value").getTensor().getConstantVal()
		numSplits = curNode.getAttrVal("num_split").getI()
		inputTensorShape = extraNodeInfoDict[inputsRef[1]][0]
		rank = len(inputTensorShape)
		if axis < 0:
			# Without this, the `j == axis` test below never matched a negative
			# axis and every output silently became the whole tensor.
			axis += rank
		assert(0 <= axis < rank)
		assert(inputTensorShape[axis] % numSplits == 0) # Should perfectly split
		# Exact integer division (divisibility asserted above); avoids going
		# through a float as int(a/b) did.
		sizeAlongSplitDim = int(inputTensorShape[axis] // numSplits)

		inputVarName = dictNodeNameToOutVarStr[inputsRef[1]]
		outputAsts = {}
		for i in range(numSplits):
			outputName = curNode.getName()
			if i != 0:
				outputName += ":" + str(i)
			subscriptRanges = []
			for j, dimSize in enumerate(inputTensorShape):
				if j == axis:
					start = i * sizeAlongSplitDim
					end = start + sizeAlongSplitDim - 1
				else:
					start = 0
					end = dimSize - 1
				subscriptRanges.append((start, end))
			outputAsts[outputName] = AST.Slice(AST.ID(inputVarName), subscriptRanges)
		return (None, outputAsts)
예제 #12
0
	def ExpandDims(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower an ExpandDims node to an uninterpreted ``ExpandDims`` call.

		Both inputs are passed through, in order, as identifiers.

		Returns (None, {curNode.getName(): AST.UninterpFuncCall(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.ExpandDims.name
		callArgs = [AST.ID(dictNodeNameToOutVarStr[ref]) for ref in inputsRef]
		expanded = AST.UninterpFuncCall(outShape, funcName, callArgs)
		return (None, {curNode.getName(): expanded})
예제 #13
0
	def Equal(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower an Equal node to an ``AST.BOp`` with the '==' operator.

		Returns (None, {curNode.getName(): AST.BOp(lhs, '==', rhs)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		nodeName = curNode.getName()
		lhs = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
		eqOp = TFNodesAST.getOperatorsIdx('==')
		rhs = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
		return (None, {nodeName: AST.BOp(lhs, eqOp, rhs)})
예제 #14
0
	def Const(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Const node to a public AST literal.

		Only DT_INT32 and DT_FLOAT tensors are handled; any other dtype trips an
		assert. If the tensor reports a single constant value (getConstantVal()
		is not None), the result depends on the shape: a rank-0 shape gives the
		scalar literal itself, and any other shape gives an uninterpreted
		CreateTensor call that fills the shape with that literal. Otherwise the
		tensor's value array is emitted element by element in an AST.Decl.

		Returns (None, {curNode.getName(): <AST>}).
		"""
		assert(len(curNode.getInputsRef()) == 0)
		attrs = curNode.getAttrMapRef()
		tensor = attrs["value"].getTensor()
		dtype = attrs["dtype"].getDataType()
		# Work on a copy so the tensor's own shape is left untouched.
		shape = tensor.getShapeRef()[:]

		fillVal = tensor.getConstantVal()
		if fillVal is None:
			# Content is given as a byte array: decode every value and declare
			# the whole tensor.
			if dtype == Graph.DataTypeEnum.DT_INT32:
				dataPassed = [AST.Int(v, 32, isSecret=False) for v in tensor.getContentAsValArr()]
			elif dtype == Graph.DataTypeEnum.DT_FLOAT:
				dataPassed = [AST.Float(v, isSecret=False) for v in tensor.getContentAsValArr()]
			else:
				assert False
			retAST = AST.Decl(shape, None, dataPassed, isSecret=False)
		else:
			dataPassed = None
			if dtype == Graph.DataTypeEnum.DT_INT32:
				dataPassed = AST.Int(fillVal, 32, isSecret=False)
			elif dtype == Graph.DataTypeEnum.DT_FLOAT:
				dataPassed = AST.Float(fillVal, isSecret=False)
			else:
				assert False

			if len(shape) == 0:
				# Rank-0: the literal itself is the value.
				retAST = dataPassed
			else:
				# Fill the whole shape with the single constant.
				retAST = AST.UninterpFuncCall(shape,
											TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
											[dataPassed],
											isSecret=False)
		return (None, {curNode.getName(): retAST})
예제 #15
0
 def StopGradient(graph: Graph.Graph, curNode: Graph.Node,
                  dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Lower a StopGradient node as a pass-through of its first input.

     Returns ``(None, {curNode.getName(): AST.ID(<first input's variable>)})``.
     """
     firstInput = curNode.getInputsRef()[0]
     nodeName = curNode.getName()
     passthrough = AST.ID(dictNodeNameToOutVarStr[firstInput])
     return (None, {nodeName: passthrough})
예제 #16
0
	def FloorDiv(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a FloorDiv node as ``floor(a ./ b)``.

		Builds an ``AST.BOp`` with the './' operator over the two inputs and
		wraps it in an ``AST.Func`` applying the 'floor' operator.

		Returns (None, {curNode.getName(): AST.Func('floor', ...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		numerator = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
		divOp = TFNodesAST.getOperatorsIdx('./')
		denominator = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
		quotient = AST.BOp(numerator, divOp, denominator)
		nodeName = curNode.getName()
		floored = AST.Func(TFNodesAST.getOperatorsIdx('floor'), quotient)
		return (None, {nodeName: floored})
예제 #17
0
 def FusedBatchNorm(graph: Graph.Graph, curNode: Graph.Node,
                    dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Lower a FusedBatchNorm node to ``AST.FusedBatchNorm``.

     Only the first three inputs are forwarded (as identifiers, in order); any
     further inputs are not read here. NOTE(review): in TF these are
     presumably x, scale and offset — confirm against the AST node.

     Returns ``(None, AST.FusedBatchNorm(...))``.
     """
     inputsRef = curNode.getInputsRef()
     first = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
     second = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
     third = AST.ID(dictNodeNameToOutVarStr[inputsRef[2]])
     batchNormAST = AST.FusedBatchNorm(first, second, third)
     return (None, batchNormAST)
예제 #18
0
	def ZerosLike(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a ZerosLike node to a public ``CreateTensor`` call filled with 0.

		The ``T`` attribute is only checked to be valid; the fill value is
		always ``AST.Int(0, isSecret=False)``. NOTE(review): float outputs are
		also filled with an integer literal — confirm this is intended.

		Returns (None, {curNode.getName(): AST.UninterpFuncCall(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 1)
		outputType = curNode.getAttrMapRef()["T"].getDataType()
		assert(outputType is not Graph.DataTypeEnum.DT_INVALID)
		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.CreateTensor.name
		zeroLiteral = AST.Int(0, isSecret=False)
		zeros = AST.UninterpFuncCall(outShape, funcName, [zeroLiteral], isSecret=False)
		return (None, {curNode.getName(): zeros})
예제 #19
0
	def Conv3DBackpropInputV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Conv3DBackpropInputV2 (3-D convTranspose) node to a '#T' BOp.

		Inputs are (output_shape, filter, input). Strides must have the form
		[1, D, H, W, 1]. Filter depth/height/width come from dims 0-2 of the
		filter's shape. Output depth/height/width come from dims 1-3 of this
		node's shape.

		The options carry the padding, strides and filter sizes of the forward
		convolution that this convTranspose inverts, plus ConvDim = 3 and the
		output image sizes.

		Returns (None, {curNode.getName(): AST.BOp(input, '#T', filter, options)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef)==3) #output_shape, filter, input

		stridesUsed = curNode.getAttrMapRef()["strides"].getList().getILi()
		assert(stridesUsed[0]==1 and stridesUsed[4]==1)
		strideD = stridesUsed[1]
		strideH = stridesUsed[2]
		strideW = stridesUsed[3]

		filterShape = extraNodeInfoDict[inputsRef[1]][0]
		FD = filterShape[0]
		FH = filterShape[1]
		FW = filterShape[2]

		# Padding is derived from this node's output dims (see below), so the
		# input's own spatial dims are not needed.
		outputShape = extraNodeInfoDict[curNode.getName()][0]
		outputD = outputShape[1]
		outputH = outputShape[2]
		outputW = outputShape[3]

		paddingUsedStr = curNode.getAttrMapRef()["padding"].getS()

		# Important: Using outputH and outputW in the below is not an error!
		#			For convTranspose, the parameters passed in the node are of the conv of which this convTranspose is an inverse.
		#			Which is why the call to helper_findPadding makes sense.
		#			The zPads below are of the conv of which this convTranspose is an inverse.
		[zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft, zPadWRight] = TFNodesAST.helper_findPadding(outputH, outputW, FH, FW, strideH, strideW, paddingUsedStr, imgD = outputD, FD = FD, strideD = strideD)

		options = {
			AST.PaddingKeysDict.FD: FD,
			AST.PaddingKeysDict.FH: FH,
			AST.PaddingKeysDict.FW: FW,
			AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
			AST.PaddingKeysDict.zPadDRight: zPadDRight,
			AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
			AST.PaddingKeysDict.zPadHRight: zPadHRight,
			AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
			AST.PaddingKeysDict.zPadWRight: zPadWRight,
			AST.PaddingKeysDict.strideD: strideD,
			AST.PaddingKeysDict.strideH: strideH,
			AST.PaddingKeysDict.strideW: strideW,
			AST.PaddingKeysDict.ConvDim: 3,
			AST.PaddingKeysDict.outputImgD: outputD,
			AST.PaddingKeysDict.outputImgH: outputH,
			AST.PaddingKeysDict.outputImgW: outputW,
		}
		return (None, { curNode.getName() : AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),
								TFNodesAST.getOperatorsIdx('#T'),
								AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
								options)})
예제 #20
0
	def Transpose(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Transpose node to ``AST.Transpose`` with a constant permutation.

		The permutation is read from input 1's ``value`` tensor. That tensor
		must have dtype kind "i" and rank 1.

		Returns (None, {curNode.getName(): AST.Transpose(input, permList)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		# Fetch the tensor value held by the perm node.
		permTensor = graph.__getitem__(inputsRef[1]).getAttrVal("value").getTensor()
		permList = permTensor.getContentAsValArr()
		assert(permTensor.getDType().kind == "i")
		assert(permTensor.getShapeRef().getRank() == 1)
		nodeName = curNode.getName()
		transposed = AST.Transpose(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]), permList)
		return (None, {nodeName: transposed})
예제 #21
0
파일: TFNodesAST.py 프로젝트: raina777/EzPC
	def Slice(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Slice node to an uninterpreted ``CreateCopy`` call.

		The call receives the source tensor, the begin indices and the sizes
		(inputs 0, 1 and 2) as identifiers.

		Returns (None, AST.UninterpFuncCall(...)).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 3)
		# The "T" attribute is looked up but its value is not used.
		curNode.getAttrMapRef()["\"T\""].getDataType()
		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.CreateCopy.name
		copyArgs = [
			AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),  # tensor to copy from
			AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),  # begin idx
			AST.ID(dictNodeNameToOutVarStr[inputsRef[2]]),  # size
		]
		return (None, AST.UninterpFuncCall(outShape, funcName, copyArgs))
예제 #22
0
 def Tile(graph: Graph.Graph, curNode: Graph.Node,
          dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Lower a Tile node to an uninterpreted ``Tile`` function call.

     The call receives input 0 and input 1 (presumably the multiples) as
     identifiers.

     Returns ``(None, AST.UninterpFuncCall(...))``.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 2)
     outShape = extraNodeInfoDict[curNode.getName()][0]
     funcName = TFNodesAST.UninterpFuncCallNames.Tile.name
     tileArgs = [
         AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
         AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
     ]
     return (None, AST.UninterpFuncCall(outShape, funcName, tileArgs))
예제 #23
0
파일: TFNodesAST.py 프로젝트: raina777/EzPC
	def Mean(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Mean node to ``AST.Reduce`` with the 'mean' operator.

		Both inputs are passed as identifiers. The optional "keep_dims" attribute
		(default False) is passed as a public 32-bit ``AST.Int`` (0 or 1).

		Returns (None, AST.Reduce(...)).
		"""
		inputsRef = curNode.getInputsRef()
		attrs = curNode.getAttrMapRef()
		assert(len(inputsRef) == 2)
		keepdims = attrs["\"keep_dims\""].getB() if "\"keep_dims\"" in attrs else False
		outShape = extraNodeInfoDict[curNode.getName()][0]
		tensorVar = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
		axesVar = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
		keepDimsFlag = AST.Int(int(keepdims), 32, isSecret=False)
		meanOp = TFNodesAST.getOperatorsIdx('mean')
		return (None, AST.Reduce(tensorVar, axesVar, keepDimsFlag, outShape, meanOp))
예제 #24
0
	def ConcatV2(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a ConcatV2 node to an uninterpreted ``Concat<N>T`` call.

		``N`` is read from the node's attribute and the node must have N + 1
		inputs (one extra for the axis). All inputs are passed as identifiers,
		and the call is flagged ``outputDiffInpDims=1``.

		Returns (None, {curNode.getName(): AST.UninterpFuncCall(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		N = curNode.getAttrMapRef()["N"].getI()
		assert(len(inputsRef) == N+1) # One extra for axis
		# TODO: Since the axis of concat is constant, it is known here - the
		#		inputs' sizes along that dim should be passed as input to the
		#		below function. For now hardcoding.
		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.Concat.name + str(N) + 'T'
		callArgs = [AST.ID(dictNodeNameToOutVarStr[ref]) for ref in inputsRef]
		concatAST = AST.UninterpFuncCall(outShape, funcName, callArgs, outputDiffInpDims=1)
		return (None, {curNode.getName(): concatAST})
예제 #25
0
	def Fill(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower a Fill node to a public ``CreateTensor`` call.

		Input 0 denotes the output shape and must itself have rank 1. Input 1 is
		the fill value, passed as an identifier. The ``T`` attribute must not be
		DT_INVALID.

		Returns (None, {curNode.getName(): AST.UninterpFuncCall(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef) == 2)
		# Shape of the shape operand itself: it must be a rank-1 tensor.
		shapeOperandShape = extraNodeInfoDict[inputsRef[0]][0]
		assert(len(shapeOperandShape) == 1)

		outputType = curNode.getAttrMapRef()["T"].getDataType()
		assert(outputType is not Graph.DataTypeEnum.DT_INVALID)

		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.CreateTensor.name
		fillValue = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
		filled = AST.UninterpFuncCall(outShape, funcName, [fillValue], isSecret=False)
		return (None, {curNode.getName(): filled})
예제 #26
0
 def ReadVariableOp(
     graph: Graph.Graph,
     curNode: Graph.Node,
     dictNodeNameToOutVarStr: dict,
     extraNodeInfoDict: dict,
 ):
     """Lower a ReadVariableOp node as a pass-through of its first input.

     Returns:
         ``(None, {curNode.getName(): AST.ID(<first input's variable>)})``.
     """
     sourceRef = curNode.getInputsRef()[0]
     nodeName = curNode.getName()
     return (None, {nodeName: AST.ID(dictNodeNameToOutVarStr[sourceRef])})
예제 #27
0
    def Conv2D(
        graph: Graph.Graph,
        curNode: Graph.Node,
        dictNodeNameToOutVarStr: dict,
        extraNodeInfoDict: dict,
    ):
        """Lower a Conv2D node to an ``AST.BOp`` with the '#' operator.

        Strides must have the form ``[1, H, W, 1]``. Image height/width come
        from dims 1 and 2 of the input's shape, and filter height/width from
        dims 0 and 1 of the filter's shape. The zero padding for the
        ``padding`` attribute is computed by ``TFNodesAST.helper_findPadding``.

        Returns:
            ``(None, {curNode.getName(): AST.BOp(input, '#', filter, options)})``.
        """
        inputsRef = curNode.getInputsRef()
        assert len(inputsRef) == 2
        attrs = curNode.getAttrMapRef()

        strides = attrs["strides"].getList().getILi()
        assert strides[0] == 1 and strides[3] == 1
        strideH, strideW = strides[1], strides[2]

        inputShape = extraNodeInfoDict[inputsRef[0]][0]
        imgH, imgW = inputShape[1], inputShape[2]

        filterShape = extraNodeInfoDict[inputsRef[1]][0]
        FH, FW = filterShape[0], filterShape[1]

        padding = attrs["padding"].getS()
        zPadHLeft, zPadHRight, zPadWLeft, zPadWRight = (
            TFNodesAST.helper_findPadding(imgH, imgW, FH, FW, strideH,
                                          strideW, padding))

        options = {
            AST.PaddingKeysDict.FH: FH,
            AST.PaddingKeysDict.FW: FW,
            AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
            AST.PaddingKeysDict.zPadHRight: zPadHRight,
            AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
            AST.PaddingKeysDict.zPadWRight: zPadWRight,
            AST.PaddingKeysDict.strideH: strideH,
            AST.PaddingKeysDict.strideW: strideW,
        }
        nodeName = curNode.getName()
        convAST = AST.BOp(
            AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
            TFNodesAST.getOperatorsIdx("#"),
            AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
            options,
        )
        return (None, {nodeName: convAST})
예제 #28
0
    def Assign(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Handle an Assign node by emitting no AST.

        The node must have exactly two inputs.

        Returns:
            ``(None, None)``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 2)

        # TODO_TAB / HACK: for inference, the CopyTensor lowering is disabled.
        # When re-enabled it emitted an uninterpreted CopyTensor call over
        # (input 0, input 1) with this node's shape from extraNodeInfoDict.
        # Fix this properly later.
        return (None, None)
예제 #29
0
    def Conv3D(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Lower a Conv3D node to an ``AST.BOp`` with the '#' operator.

        Strides must have the form ``[1, D, H, W, 1]``. Image
        depth/height/width come from dims 1-3 of the input's shape, and filter
        depth/height/width from dims 0-2 of the filter's shape. Zero padding is
        computed by ``TFNodesAST.helper_findPadding``, and the options also
        set ``ConvDim`` to 3.

        Returns:
            ``(None, {curNode.getName(): AST.BOp(input, '#', filter, options)})``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 2)
        attrs = curNode.getAttrMapRef()

        strides = attrs["strides"].getList().getILi()
        assert (strides[0] == 1 and strides[4] == 1)
        strideD, strideH, strideW = strides[1], strides[2], strides[3]

        inputShape = extraNodeInfoDict[inputsRef[0]][0]
        imgD, imgH, imgW = inputShape[1], inputShape[2], inputShape[3]

        filterShape = extraNodeInfoDict[inputsRef[1]][0]
        FD, FH, FW = filterShape[0], filterShape[1], filterShape[2]

        padding = attrs["padding"].getS()
        (zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft,
         zPadWRight) = TFNodesAST.helper_findPadding(imgH, imgW, FH, FW,
                                                     strideH, strideW,
                                                     padding, imgD, FD,
                                                     strideD)

        options = {
            AST.PaddingKeysDict.FD: FD,
            AST.PaddingKeysDict.FH: FH,
            AST.PaddingKeysDict.FW: FW,
            AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
            AST.PaddingKeysDict.zPadDRight: zPadDRight,
            AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
            AST.PaddingKeysDict.zPadHRight: zPadHRight,
            AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
            AST.PaddingKeysDict.zPadWRight: zPadWRight,
            AST.PaddingKeysDict.strideD: strideD,
            AST.PaddingKeysDict.strideH: strideH,
            AST.PaddingKeysDict.strideW: strideW,
            AST.PaddingKeysDict.ConvDim: 3,
        }
        nodeName = curNode.getName()
        convAST = AST.BOp(AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                          TFNodesAST.getOperatorsIdx('#'),
                          AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
                          options)
        return (None, {nodeName: convAST})
예제 #30
0
	def Identity(graph : Graph.Graph, curNode : Graph.Node, dictNodeNameToOutVarStr : dict, extraNodeInfoDict : dict):
		"""Lower an Identity node to an uninterpreted ``CreateIdentity`` call.

		In SeeDot, ``J2 = J1`` creates a new reference to J1, so the generated
		code cannot simply be ``J2 = J1``. Instead a new tensor is created and
		the old one is assigned into it. The ``T`` attribute must not be
		DT_INVALID.

		Returns (None, {curNode.getName(): AST.UninterpFuncCall(...)}).
		"""
		inputsRef = curNode.getInputsRef()
		assert(len(inputsRef)==1)

		dataType = curNode.getAttrMapRef()["T"].getDataType()
		assert(dataType is not Graph.DataTypeEnum.DT_INVALID)

		outShape = extraNodeInfoDict[curNode.getName()][0]
		funcName = TFNodesAST.UninterpFuncCallNames.CreateIdentity.name
		sourceVar = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
		identityCall = AST.UninterpFuncCall(outShape, funcName, [sourceVar])
		return (None, {curNode.getName(): identityCall})