Exemplo n.º 1
0
    def Const(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Translate a TF ``Const`` node into an AST expression.

        Two encodings of the constant tensor are handled:

        * If ``tensor.getConstantVal()`` is not None, every element shares
          that single value.
          - A rank-0 shape yields the scalar literal itself.
          - Otherwise the result is an uninterpreted ``CreateTensor`` call
            that fills a tensor of the node's shape with that value.
        * Otherwise the element values are read from
          ``tensor.getContentAsValArr()`` and emitted as an ``AST.Decl``
          literal of the node's shape.

        Args:
            graph: Enclosing graph (not used by this translation).
            curNode: The Const node; it must have no inputs.
            dictNodeNameToOutVarStr: Not used by this translation.
            extraNodeInfoDict: Not used by this translation.

        Returns:
            A tuple ``(None, retAST)``.

        Raises:
            AssertionError: if the node has inputs, or its ``dtype`` is
                neither ``DT_INT32`` nor ``DT_FLOAT``. These are raised
                explicitly so the checks survive ``python -O``.
        """
        if len(curNode.getInputsRef()) != 0:
            raise AssertionError("Const node is expected to have no inputs")
        attrMap = curNode.getAttrMapRef()
        tensor = attrMap["\"value\""].getTensor()
        curNodeDataType = attrMap["\"dtype\""].getDataType()
        # Create a different copy so the original shape is not changed.
        curNodeShape = tensor.getShapeRef()[:]

        # Choose the AST literal constructor for this dtype once, so both
        # encodings below share the same dispatch.
        if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
            def makeLiteral(val):
                return AST.Int(val, 32)
        elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
            makeLiteral = AST.Float
        else:
            raise AssertionError(
                "Const: unsupported dtype {}".format(curNodeDataType))

        tensorConstantVal = tensor.getConstantVal()
        if tensorConstantVal is not None:
            # Use an uninterpreted call of CreateTensor to create the tensor
            # and fill it with the constant value.
            dataPassed = makeLiteral(tensorConstantVal)
            if len(curNodeShape) == 0:
                # This is a constant element (scalar).
                retAST = dataPassed
            else:
                retAST = AST.UninterpFuncCall(
                    curNodeShape,
                    TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
                    [dataPassed],
                    isSecret=False)
        else:
            # The tensor content is given as a byte array. Extract the value
            # array from it and build a literal declaration.
            dataPassed = [makeLiteral(val)
                          for val in tensor.getContentAsValArr()]
            retAST = AST.Decl(curNodeShape,
                              None,
                              None,
                              dataPassed,
                              isSecret=False)
        return (None, retAST)
Exemplo n.º 2
0
 def visitFloat(self, ctx: SeeDotParser.FloatContext):
     """Return an ``AST.Float`` built from a float-literal parse-tree node.

     The text of the node's ``FloatConst`` token is converted with the
     builtin ``float`` and wrapped in ``AST.Float``.
     """
     literalText = ctx.FloatConst().getText()
     parsedValue = float(literalText)
     return AST.Float(parsedValue)