Example #1
    def ReduceMean(node, value_info, node_name_to_out_var_dict,
                   innermost_let_ast_node, out_var_count, mtdAST):
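        # Lower an ONNX ReduceMean node to a SeeDot uninterpreted 'ReduceMeanO' call over the two given axes.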
        node = OnnxNode(node)
        inputsRef = node.inputs

        keepdims = node.attrs['keepdims']
        axes = node.attrs['axes']

        # Only the keepdims == 0 case with exactly two reduction axes is currently supported.
        assert (keepdims == 0)
        assert (len(axes) == 2)

        seedot_output_ast = AST.UninterpFuncCall(
            value_info[node.outputs[0]][1], 'ReduceMeanO', [
                AST.ID(node_name_to_out_var_dict[inputsRef[0]]),
                AST.Int(axes[0], 32, False),
                AST.Int(axes[1], 32, False)
            ])
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[node.outputs[0]] = output_name
        return (innermost_let_ast_node, out_var_count)
Example #2
 def ReluGrad(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
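     # ReLU gradient: emit a conditional keyed on the second input that selects between the first input and 0.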
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 2)
     return (None,
             AST.Cond(AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),
                      AST.Int(1),
                      AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                      AST.Int(0)))
Example #3
    def Concat(node, value_info, node_name_to_out_var_dict,
               innermost_let_ast_node, out_var_count, mtdAST):
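        # Concatenate the N inputs along 'axis' via an uninterpreted 'Concat<N>T' call.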
        node = OnnxNode(node)
        if (DEBUG):
            print(node)
        inputsRef = node.inputs
        N = len(inputsRef)

        inputs = [
            AST.ID(node_name_to_out_var_dict[inputsRef[x]])
            for x in range(0, len(inputsRef))
        ]
        axis = node.attrs['axis']

        seedot_output_ast = AST.UninterpFuncCall(
            list(value_info[node.outputs[0]][1]),
            'Concat' + str(N) + 'T',
            inputs + [AST.Int(axis, 32, False)],
            outputDiffInpDims=1)

        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        node_name_to_out_var_dict[node.outputs[0]] = output_name

        return (innermost_let_ast_node, out_var_count)
Example #4
    def Const(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
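        # Materialize a constant node: fill a tensor with a single scalar when a constant value is present, otherwise declare it from the explicit value list.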
        assert (len(curNode.getInputsRef()) == 0)
        tensor = curNode.getAttrMapRef()["\"value\""].getTensor()
        curNodeDataType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
        # create a different copy to not change the original copy
        curNodeShape = tensor.getShapeRef()[:]

        tensorConstantVal = tensor.getConstantVal()
        if tensorConstantVal is not None:
            # Use an uninterpreted call to CreateTensor to create the tensor and fill it with a constant value
            dataPassed = None
            if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
                dataPassed = AST.Int(tensorConstantVal, 32)
            elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
                dataPassed = AST.Float(tensorConstantVal)
            else:
                assert False

            if (len(curNodeShape) == 0):
                # This is a constant element
                retAST = dataPassed
            else:
                retAST = AST.UninterpFuncCall(
                    curNodeShape,
                    TFNodesAST.UninterpFuncCallNames.CreateTensor.name,
                    [dataPassed],
                    isSecret=False)
        else:
            # The tensor content is given as a byte array. Extract the value array from the byte array and build the AST.
            if curNodeDataType == Graph.DataTypeEnum.DT_INT32:
                dataPassed = list(
                    map(lambda x: AST.Int(x, 32),
                        tensor.getContentAsValArr()[:]))
            elif curNodeDataType == Graph.DataTypeEnum.DT_FLOAT:
                dataPassed = list(
                    map(lambda x: AST.Float(x),
                        tensor.getContentAsValArr()[:]))
            else:
                assert False
            retAST = AST.Decl(curNodeShape,
                              None,
                              None,
                              dataPassed,
                              isSecret=False)
        return (None, retAST)
Example #5
 def ZerosLike(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
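     # Create an all-zeros tensor with the node's output shape via an uninterpreted CreateTensor call.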
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 1)
     curNodeOutputType = curNode.getAttrMapRef()["\"T\""].getDataType()
     assert (curNodeOutputType is not Graph.DataTypeEnum.DT_INVALID)
     retAST = AST.UninterpFuncCall(
         extraNodeInfoDict[curNode.getName()][0],
         TFNodesAST.UninterpFuncCallNames.CreateTensor.name, [AST.Int(0)],
         isSecret=False)
     return (None, retAST)
Example #6
    def Split(node, value_info, node_name_to_out_var_dict,
              innermost_let_ast_node, out_var_count, mtdAST):
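        # Emit one uninterpreted 'Split' call per output, recording the split axis, the chunk index, and the total output count.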
        node = OnnxNode(node)
        inputsRef = node.inputs
        output_count = len(node.outputs)

        for cur_count in range(output_count):
            seedot_output_ast = AST.UninterpFuncCall(
                list(value_info[node.outputs[cur_count]][1]), 'Split', [
                    AST.ID(node_name_to_out_var_dict[inputsRef[0]]),
                    AST.Int(node.attrs['axis'], 32, False),
                    AST.Int(cur_count, 32, False),
                    AST.Int(output_count, 32, False)
                ])
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1
            node_name_to_out_var_dict[node.outputs[cur_count]] = output_name

        return (innermost_let_ast_node, out_var_count)
Example #7
 def Pack(graph: Graph.Graph, curNode: Graph.Node,
          dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
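     # Stack the N inputs along 'axis' via an uninterpreted 'Pack' call.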
     inputsRef = curNode.getInputsRef()
     N = curNode.getAttrMapRef()["\"N\""].getI()
     axis = curNode.getAttrMapRef()["\"axis\""].getI()
     assert (len(inputsRef) == N)
     retAST = AST.UninterpFuncCall(
         extraNodeInfoDict[curNode.getName()][0],
         TFNodesAST.UninterpFuncCallNames.Pack.name,
         list(map(lambda x: AST.ID(dictNodeNameToOutVarStr[x]),
                  inputsRef)) + [AST.Int(axis)])
     return (None, retAST)
Example #8
 def TruncatedNormal(graph: Graph.Graph, curNode: Graph.Node,
                     dictNodeNameToOutVarStr: dict,
                     extraNodeInfoDict: dict):
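     # Emit an uninterpreted 'TruncatedNormal' call parameterized by the dtype name and the output shape.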
     curNodeDataType = curNode.getAttrMapRef()["\"dtype\""].getDataType()
     assert (curNodeDataType is not Graph.DataTypeEnum.DT_INVALID)
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 1)
     curNodeOutputShape = extraNodeInfoDict[curNode.getName()][0]
     return (None,
             AST.UninterpFuncCall(
                 extraNodeInfoDict[curNode.getName()][0],
                 TFNodesAST.UninterpFuncCallNames.TruncatedNormal.name,
                 [AST.ID(curNodeDataType.name)] +
                 list(map(lambda x: AST.Int(x), curNodeOutputShape)))
             )  # TODO
Example #9
 def LogSoftmax(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
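     # log_softmax(x) = x - log(sum(exp(x))) along the last axis, built from exp, reduce-sum, log and a negation.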
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 1)
     expAST = AST.Func(TFNodesAST.getOperatorsIdx('exp'),
                       AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))
     reduceAST = AST.Reduce(expAST, AST.Int(-1),
                            TFNodesAST.getOperatorsIdx('+'))
     return (None,
             AST.BOp(
                 AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
                 TFNodesAST.getOperatorsIdx('+'),
                 AST.UOp(
                     TFNodesAST.getOperatorsIdx('-'),
                     AST.Func(TFNodesAST.getOperatorsIdx('log'),
                              reduceAST))))
Example #10
    def Squeeze(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        # TODO: Do this in a somewhat better way
        inputsRef = curNode.getInputsRef()
        inputTensorShape = extraNodeInfoDict[inputsRef[0]][0]
        inputTensorRank = len(inputTensorShape)

        squeezeDims = curNode.getAttrMapRef()["\"squeeze_dims\""].getList().getILi()
        squeezeDimsRank = len(squeezeDims)

        return (None,
                AST.UninterpFuncCall(
                    extraNodeInfoDict[curNode.getName()][0],
                    TFNodesAST.UninterpFuncCallNames.Squeeze.name,
                    list(map(lambda x: AST.Int(x, 32), squeezeDims)) +
                    [AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])]))
Example #11
 def Slice(graph: Graph.Graph, curNode: Graph.Node,
           dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 3)
     curNodeDataType = curNode.getAttrMapRef()["\"T\""].getDataType()
     curNodeShapeASTLi = list(
         map(lambda x: AST.Int(x), extraNodeInfoDict[curNode.getName()][0]))
     retAST = AST.UninterpFuncCall(
         extraNodeInfoDict[curNode.getName()][0],
         TFNodesAST.UninterpFuncCallNames.CreateCopy.name,
         [
             AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),  # tensor to copy from
             AST.ID(dictNodeNameToOutVarStr[inputsRef[1]]),  # begin indices
             AST.ID(dictNodeNameToOutVarStr[inputsRef[2]])   # sizes along each dimension
         ])
     return (None, retAST)
Example #12
    def AvgPool(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
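        # Average pooling over an NHWC input; strides and kernel sizes must be 1 in the batch and channel dimensions.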
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 1)

        options = {}

        stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
        assert ((stridesUsed[0] == 1) and (stridesUsed[3] == 1))
        strideH = stridesUsed[1]
        strideW = stridesUsed[2]

        kSizeUsed = curNode.getAttrMapRef()["\"ksize\""].getList().getILi()
        assert ((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1))
        kSizeH = kSizeUsed[1]
        kSizeW = kSizeUsed[2]

        paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
        zPadH = zPadW = -1
        if (paddingUsedStr == "\"SAME\""):
            zPadH = int((kSizeH - 1) / 2)
            zPadW = int((kSizeW - 1) / 2)
        elif (paddingUsedStr == "\"VALID\""):
            zPadH = zPadW = 0
        else:
            zPadH = zPadW = -1

        inputShape = extraNodeInfoDict[inputsRef[0]][0]
        imgH = inputShape[1]
        imgW = inputShape[2]
        return (None,
                AST.UninterpFuncCall(
                    extraNodeInfoDict[curNode.getName()][0],
                    TFNodesAST.UninterpFuncCallNames.AvgPool.name, [
                        AST.Int(kSizeH, 32),
                        AST.Int(kSizeW, 32),
                        AST.Int(zPadH, 32),
                        AST.Int(zPadW, 32),
                        AST.Int(strideH, 32),
                        AST.Int(strideW, 32),
                        AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
                    ]))
Example #13
 def visitInt(self, ctx: SeeDotParser.IntContext):
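     # Grammar visitor: turn an integer literal token into an AST.Int node.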
     value = int(ctx.IntConst().getText())
     return AST.Int(value)
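
All of the ONNX-style handlers above (ReduceMean, Concat, Split) share one calling convention: they take the current innermost let-node, a running output-variable counter, and a node-name-to-variable map, and return the updated let-node and counter. The sketch below shows how such handlers might be driven; the driver name process_graph and the dispatch via getattr on a handler class (called OnnxNodesAST here) are illustrative assumptions, not code taken from the examples.

def process_graph(onnx_nodes, value_info, innermost_let_ast_node, mtdAST):
    # Hypothetical driver (assumed names): dispatch each ONNX node to a handler
    # by op type (e.g. 'ReduceMean', 'Concat', 'Split') and thread the let-chain,
    # variable counter, and output-name map through every call.
    node_name_to_out_var_dict = {}
    out_var_count = 0
    for node in onnx_nodes:
        handler = getattr(OnnxNodesAST, node.op_type)  # OnnxNodesAST is an assumed container class
        (innermost_let_ast_node, out_var_count) = handler(
            node, value_info, node_name_to_out_var_dict,
            innermost_let_ast_node, out_var_count, mtdAST)
    return (innermost_let_ast_node, out_var_count)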