def Const(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Build the AST for a constant graph node.

    The node must have no inputs. The tensor is read from the node's
    ``"value"`` attribute and the element type from its ``"dtype"``
    attribute. Only ``DT_INT32`` and ``DT_FLOAT`` are handled.

    ``graph``, ``dictNodeNameToOutVarStr`` and ``extraNodeInfoDict`` are not
    read by this function.

    Returns:
        A tuple ``(None, retAST)`` where ``retAST`` is
        * the scalar literal itself, when the tensor carries a single
          constant value and its shape is empty;
        * an ``AST.UninterpFuncCall`` to ``CreateTensor`` filled with that
          literal, when the tensor carries a single constant value and has a
          non-empty shape;
        * an ``AST.Decl`` holding one literal per element, when the tensor
          content is given as a value array instead.

    Raises:
        AssertionError: if the node has inputs, or if its dtype is neither
            ``DT_INT32`` nor ``DT_FLOAT``.
    """
    assert (len(curNode.getInputsRef()) == 0)
    tensor = curNode.getAttrMapRef()["\"value\""].getTensor()
    curNodeDataType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
    # Slice-copy the shape so the tensor's own shape list is not shared
    # with (and cannot be changed through) the AST node built below.
    curNodeShape = tensor.getShapeRef()[:]
    tensorConstantVal = tensor.getConstantVal()

    # Choose the literal constructor once; both branches below need it.
    if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
        def makeLiteral(val):
            return AST.Int(val, 32)
    elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
        def makeLiteral(val):
            return AST.Float(val)
    else:
        # Raised explicitly rather than via `assert False`: a bare assert is
        # removed under `python -O`, which would let an unsupported dtype
        # fall through with no literal built.
        raise AssertionError(
            "Const: unsupported data type %r" % (curNodeDataType,))

    if tensorConstantVal is not None:
        # Use an uninterpreted call of CreateTensor to create the tensor and
        # fill it with a constant value.
        dataPassed = makeLiteral(tensorConstantVal)
        if len(curNodeShape) == 0:
            # This is a constant element.
            retAST = dataPassed
        else:
            retAST = AST.UninterpFuncCall(
                curNodeShape,
                TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
                [dataPassed],
                isSecret=False)
    else:
        # The tensor content is given as a byte array. Extract the value
        # array from it and create one literal per element.
        dataPassed = [makeLiteral(val)
                      for val in tensor.getContentAsValArr()[:]]
        retAST = AST.Decl(curNodeShape, None, None, dataPassed,
                          isSecret=False)
    return (None, retAST)
def visitFloat(self, ctx: SeeDotParser.FloatContext):
    """Return an ``AST.Float`` for the float constant token held by ``ctx``.

    The token text is taken from ``ctx.FloatConst()`` and converted with the
    built-in ``float`` before being wrapped in the AST node.
    """
    literal_text = ctx.FloatConst().getText()
    return AST.Float(float(literal_text))