Example No. 1
    def Transpose(node, value_info, node_name_to_out_var_dict,
                  innermost_let_ast_node, out_var_count, mtdAST):
        node = OnnxNode(node)
        if (DEBUG):
            print(node)

        inputsRef = node.inputs
        assert (len(inputsRef) == 1)

        seedot_output_ast = AST.Transpose(
            AST.ID(node_name_to_out_var_dict[inputsRef[0]]),
            node.attrs['perm'])
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[node.outputs[0]] = output_name

        return (innermost_let_ast_node, out_var_count)
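ONNX's Transpose carries a perm attribute that maps directly onto numpy's axis permutation, which is what the handler above forwards into AST.Transpose. A minimal numpy sketch of those semantics (illustrative only, not part of the Athos source):

    import numpy as np

    # ONNX Transpose with perm=[0, 2, 3, 1] turns an NCHW tensor into NHWC.
    x = np.random.rand(1, 3, 4, 5)      # [N, C, H, W]
    perm = [0, 2, 3, 1]
    y = np.transpose(x, axes=perm)      # [N, H, W, C]
    assert y.shape == (1, 4, 5, 3)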
Example No. 2
    def GlobalAveragePool(node, value_info, node_name_to_out_var_dict,
                          innermost_let_ast_node, out_var_count, mtdAST):
        node = OnnxNode(node)
        if (DEBUG):
            print(node)
        inputsRef = node.inputs
        assert (len(inputsRef) == 1)

        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.Pool(
            AST.Pool.PoolType.AvgPool, AST.ID(reshaped_input_name), {
                AST.PaddingKeysDict.FH: value_info[inputsRef[0]][1][2],
                AST.PaddingKeysDict.FW: value_info[inputsRef[0]][1][3],
                AST.PaddingKeysDict.zPadHLeft: 0,
                AST.PaddingKeysDict.zPadHRight: 0,
                AST.PaddingKeysDict.zPadWLeft: 0,
                AST.PaddingKeysDict.zPadWRight: 0,
                AST.PaddingKeysDict.strideH: 1,
                AST.PaddingKeysDict.strideW: 1
            })
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        reshaped_output_name = get_new_var_name(out_var_count)
        onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info,
                                                  output_name)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, onnx_output_ast, reshaped_output_name,
            mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name

        return (innermost_let_ast_node, out_var_count)
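The handler expresses GlobalAveragePool as an AvgPool whose kernel spans the whole spatial extent (FH = H, FW = W) with stride 1 and zero padding. A small numpy sketch of that equivalence (illustrative only):

    import numpy as np

    x = np.random.rand(2, 3, 7, 7)                    # [N, C, H, W]
    global_avg = x.mean(axis=(2, 3), keepdims=True)   # [N, C, 1, 1]
    # Equivalent average pool: kernel = (H, W), stride 1, zero padding.
    pooled = x.reshape(2, 3, -1).mean(axis=2).reshape(2, 3, 1, 1)
    assert np.allclose(global_avg, pooled)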
Example No. 3
    def AvgPool(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 1)

        options = {}

        stridesUsed = curNode.getAttrMapRef()["\"strides\""].getList().getILi()
        assert ((stridesUsed[0] == 1) and (stridesUsed[3] == 1))
        strideH = stridesUsed[1]
        strideW = stridesUsed[2]

        kSizeUsed = curNode.getAttrMapRef()["\"ksize\""].getList().getILi()
        assert ((kSizeUsed[0] == 1) and (kSizeUsed[3] == 1))
        kSizeH = kSizeUsed[1]
        kSizeW = kSizeUsed[2]

        paddingUsedStr = curNode.getAttrMapRef()["\"padding\""].getS()
        zPadH = zPadW = -1
        if (paddingUsedStr == "\"SAME\""):
            zPadH = int((kSizeH - 1) / 2)
            zPadW = int((kSizeW - 1) / 2)
        elif (paddingUsedStr == "\"VALID\""):
            zPadH = zPadW = 0
        else:
            zPadH = zPadW = -1

        inputShape = extraNodeInfoDict[inputsRef[0]][0]
        imgH = inputShape[1]
        imgW = inputShape[2]
        return (None,
                AST.UninterpFuncCall(
                    extraNodeInfoDict[curNode.getName()][0],
                    TFNodesAST.UninterpFuncCallNames.AvgPool.name, [
                        AST.Int(kSizeH, 32),
                        AST.Int(kSizeW, 32),
                        AST.Int(zPadH, 32),
                        AST.Int(zPadW, 32),
                        AST.Int(strideH, 32),
                        AST.Int(strideW, 32),
                        AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
                    ]))
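The SAME/VALID branch above reduces to a single symmetric pad of (k - 1) // 2 per spatial axis for SAME and 0 for VALID. A standalone sketch of that rule (the helper name is hypothetical):

    def same_padding(kernel_size: int) -> int:
        # TensorFlow-style SAME padding for stride 1: pad so that the
        # output spatial size equals the input spatial size.
        return (kernel_size - 1) // 2

    assert same_padding(3) == 1   # 3x3 kernel -> pad 1
    assert same_padding(5) == 2   # 5x5 kernel -> pad 2
    # VALID padding is simply 0.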
Example No. 4
    def Reshape(node, value_info, node_name_to_out_var_dict,
                innermost_let_ast_node, out_var_count, mtdAST):
        node = OnnxNode(node)
        if (DEBUG):
            print(node)

        inputsRef = node.inputs
        assert (len(inputsRef) == 2)
        # print(list(value_info[node.outputs[0]][1]))

        seedot_output_ast = AST.Reshape(
            AST.ID(node_name_to_out_var_dict[inputsRef[0]]),
            list(value_info[node.outputs[0]][1]), None)
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[node.outputs[0]] = output_name

        return (innermost_let_ast_node, out_var_count)
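The target shape is taken from value_info for the node's output rather than from the second (shape) input, since the ONNX shape tensor may contain 0 (copy) and -1 (infer) placeholders. A numpy illustration (values are made up):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    onnx_shape_input = [0, -1]        # ONNX allows 0 (copy dim) and -1 (infer)
    resolved_output_shape = (2, 12)   # what shape inference / value_info records
    assert np.reshape(x, resolved_output_shape).shape == (2, 12)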
Example No. 5
    def Split(node, value_info, node_name_to_out_var_dict,
              innermost_let_ast_node, out_var_count, mtdAST):
        node = OnnxNode(node)
        inputsRef = node.inputs
        output_count = len(node.outputs)

        for cur_count in range(output_count):
            seedot_output_ast = AST.UninterpFuncCall(
                list(value_info[node.outputs[cur_count]][1]), 'Split', [
                    AST.ID(node_name_to_out_var_dict[inputsRef[0]]),
                    AST.Int(node.attrs['axis'], 32, False),
                    AST.Int(cur_count, 32, False),
                    AST.Int(output_count, 32, False)
                ])
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1
            node_name_to_out_var_dict[node.outputs[cur_count]] = output_name

        return (innermost_let_ast_node, out_var_count)
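Each loop iteration emits one 'Split' call carrying the axis, the slice index and the total number of outputs; for equal-sized pieces this matches numpy's split semantics (sketch only):

    import numpy as np

    x = np.arange(12).reshape(2, 6)
    axis, output_count = 1, 3
    pieces = np.split(x, output_count, axis=axis)
    for cur_count, piece in enumerate(pieces):
        # cur_count plays the role of the AST.Int(cur_count, ...) argument above.
        assert piece.shape == (2, 2)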
Example No. 6
    def conv3d(node, value_info, node_name_to_out_var_dict,
               innermost_let_ast_node, out_var_count, mtdAST):
        inputsRef = node.inputs
        inputShape = value_info[inputsRef[0]][1]
        filterShape = value_info[inputsRef[1]][1]
        stridesUsed = node.attrs['strides']

        assert (len(inputsRef) == 2 or len(inputsRef) == 3)
        assert (len(stridesUsed) == 3)
        assert (value_info[node.inputs[1]][1][2:] == tuple(
            node.attrs['kernel_shape']))
        # verify this order
        [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft,
         zPadWRight] = node.attrs['pads']

        assert (inputShape[1] == filterShape[1])
        # For the input:
        # ONNX order is [N, CI, D, H, W]; it must be changed to
        # [N, D, H, W, CI] order.
        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1

        # For the filter:
        # ONNX order is [CO, CI1, FD, FH, FW]; it must be changed to
        # [FD, FH, FW, CI1, CO] order.
        reshaped_filter_name = get_new_var_name(out_var_count)
        reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                                  node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_filter, reshaped_filter_name,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.Bop1(AST.ID(reshaped_input_name),
                                     getOperatorsIdx('#'),
                                     AST.ID(reshaped_filter_name))
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        # If there is a bias input, reshape it and add it.
        if (len(inputsRef) == 3):
            reshaped_bias_name = get_new_var_name(out_var_count)
            reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info,
                                                  node_name_to_out_var_dict, 3)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, reshaped_bias, reshaped_bias_name,
                mtdAST)
            out_var_count += 1

            seedot_output_ast = AST.Bop1(AST.ID(output_name),
                                         getOperatorsIdx('+'),
                                         AST.ID(reshaped_bias_name))
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1

        return (innermost_let_ast_node, out_var_count, output_name)
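The two reshape steps implement the layout changes described in the comments. The permutations themselves are standard; a numpy sketch of them (not taken from the Athos helpers):

    import numpy as np

    x = np.random.rand(1, 4, 8, 8, 8)      # ONNX input  [N, CI, D, H, W]
    w = np.random.rand(16, 4, 3, 3, 3)     # ONNX filter [CO, CI, FD, FH, FW]

    x_seedot = np.transpose(x, (0, 2, 3, 4, 1))   # [N, D, H, W, CI]
    w_seedot = np.transpose(w, (2, 3, 4, 1, 0))   # [FD, FH, FW, CI, CO]
    assert x_seedot.shape == (1, 8, 8, 8, 4)
    assert w_seedot.shape == (3, 3, 3, 4, 16)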
Example No. 7
    def conv2d(node, value_info, node_name_to_out_var_dict,
               innermost_let_ast_node, out_var_count, mtdAST):
        inputsRef = node.inputs
        inputShape = value_info[inputsRef[0]][1]
        filterShape = value_info[inputsRef[1]][1]

        stridesUsed = node.attrs['strides'] if 'strides' in node.attrs else [
            1, 1
        ]
        group = node.attrs['group'] if 'group' in node.attrs else 1
        padding = node.attrs['pads'] if 'pads' in node.attrs else [0, 0, 0, 0]
        dilation = node.attrs['dilation'] if 'dilation' in node.attrs else [
            1, 1
        ]

        assert (len(inputsRef) == 2 or len(inputsRef) == 3)
        assert (len(stridesUsed) == 2)

        # We assume the VALID case when the padding is given in string format.

        # print(inputShape, filterShape)
        assert (inputShape[1] == filterShape[1] * group)
        # For the input:
        # ONNX order is [N, CI, H, W]; it must be changed to
        # [N, H, W, CI] order.
        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1

        # For the filter:
        # ONNX order is [CO, CI1, FH, FW]; it must be changed to
        # [FH, FW, CI1, CO] order.
        reshaped_filter_name = get_new_var_name(out_var_count)
        reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                                  node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_filter, reshaped_filter_name,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.Convolution(AST.ID(reshaped_input_name),
                                            AST.ID(reshaped_filter_name),
                                            stridesUsed, padding, dilation,
                                            group)
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        # If there is a bias input, reshape it and add it.
        if (len(inputsRef) == 3):
            seedot_output_ast = AST.Bop1(AST.ID(output_name),
                                         SeeDotParser.ADDCIR,
                                         AST.ID(inputsRef[2]))
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1

        return (innermost_let_ast_node, out_var_count, output_name)
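Grouped convolution splits the input channels into group equal parts, so the filter's per-group channel count times group must match the input channel count, which is what the assertion in conv2d checks. A small numeric illustration (values are made up):

    # ONNX Conv shapes: input [N, CI, H, W], filter [CO, CI // group, FH, FW]
    N, CI, H, W = 1, 8, 16, 16
    group = 4
    CO = 12
    filter_shape = (CO, CI // group, 3, 3)
    # The check performed in conv2d above:
    assert CI == filter_shape[1] * group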
Example No. 8
def get_reshaped_output_ast(onnx_output_name, value_info, output_name):
    onnx_output_shape = list(value_info[onnx_output_name][1])
    onnx_output_order = get_onnx_order(onnx_output_shape)
    return AST.Reshape(AST.ID(output_name), onnx_output_shape,
                       onnx_output_order)
Example No. 9
def get_reshaped_input_ast(input_name, value_info, node_name_to_out_var_dict):
    onnx_input_shape = list(value_info[input_name][1])
    (seedot_input_shape,
     seedot_input_order) = get_seedot_shape_order(onnx_input_shape)
    return AST.Reshape(AST.ID(node_name_to_out_var_dict[input_name]),
                       seedot_input_shape, seedot_input_order)
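get_seedot_shape_order is not shown in this listing; from its use in the convolution handlers it presumably returns the channels-last shape together with the corresponding axis order. A hedged sketch of such a helper for the 4D case (an assumption, not the actual Athos implementation):

    def get_seedot_shape_order_sketch(onnx_shape):
        # Assumed behaviour only: map an NCHW shape to NHWC and return both the
        # channels-last shape and the axis permutation (0-based here; the real
        # helper may use a different indexing convention).
        if len(onnx_shape) == 4:
            n, c, h, w = onnx_shape
            return [n, h, w, c], [0, 2, 3, 1]
        raise NotImplementedError("only the 4D case is sketched here")

    assert get_seedot_shape_order_sketch([1, 3, 32, 32]) == ([1, 32, 32, 3], [0, 2, 3, 1])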
Example No. 10
    def conv3dtranspose(node, value_info, node_name_to_out_var_dict,
                        innermost_let_ast_node, out_var_count, mtdAST):
        inputsRef = node.inputs
        inputShape = value_info[inputsRef[0]][1]
        filterShape = value_info[inputsRef[1]][1]
        stridesUsed = node.attrs['strides']
        outputShape = value_info[node.outputs[0]][1]

        # sometimes there is a bias to be added as well
        assert (len(inputsRef) == 2 or len(inputsRef) == 3)
        assert (len(stridesUsed) == 3)
        assert (value_info[node.inputs[1]][1][2:] == tuple(
            node.attrs['kernel_shape']))
        # verify this order
        [zPadDLeft, zPadDRight, zPadHLeft, zPadHRight, zPadWLeft,
         zPadWRight] = node.attrs['pads']

        options = {}
        options[AST.PaddingKeysDict.FD] = filterShape[2]
        options[AST.PaddingKeysDict.FH] = filterShape[3]
        options[AST.PaddingKeysDict.FW] = filterShape[4]
        options[AST.PaddingKeysDict.zPadDLeft] = zPadDLeft
        options[AST.PaddingKeysDict.zPadDRight] = zPadDRight
        options[AST.PaddingKeysDict.zPadHLeft] = zPadHLeft
        options[AST.PaddingKeysDict.zPadHRight] = zPadHRight
        options[AST.PaddingKeysDict.zPadWLeft] = zPadWLeft
        options[AST.PaddingKeysDict.zPadWRight] = zPadWRight
        options[AST.PaddingKeysDict.strideD] = stridesUsed[0]
        options[AST.PaddingKeysDict.strideH] = stridesUsed[1]
        options[AST.PaddingKeysDict.strideW] = stridesUsed[2]
        options[AST.PaddingKeysDict.ConvDim] = 3
        options[AST.PaddingKeysDict.outputImgD] = outputShape[2]
        options[AST.PaddingKeysDict.outputImgH] = outputShape[3]
        options[AST.PaddingKeysDict.outputImgW] = outputShape[4]

        assert (inputShape[1] == filterShape[0])
        # For the input:
        # ONNX order is [N, CI, D, H, W]; it must be changed to
        # [N, D, H, W, CI] order.

        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1
        # For the filter:
        # ONNX order is [CI, CO, FD, FH, FW]; it must be changed to
        # [FD, FH, FW, CI1, CO] order.
        reshaped_filter_name = get_new_var_name(out_var_count)
        reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                                  node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_filter, reshaped_filter_name,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name),
                                    getOperatorsIdx('#T'),
                                    AST.ID(reshaped_filter_name), options)
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        # If there is a bias input, reshape it and add it.
        if (len(inputsRef) == 3):
            biasShape = value_info[inputsRef[2]][1]
            reshaped_bias_name = get_new_var_name(out_var_count)
            reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info,
                                                  node_name_to_out_var_dict, 3)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, reshaped_bias, reshaped_bias_name,
                mtdAST)
            out_var_count += 1

            seedot_output_ast = AST.BOp(AST.ID(output_name),
                                        getOperatorsIdx('+'),
                                        AST.ID(reshaped_bias_name), options)
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1

        return (innermost_let_ast_node, out_var_count, output_name)
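conv3dtranspose takes the output spatial sizes directly from value_info instead of recomputing them. For reference, the standard ONNX ConvTranspose relation per spatial axis (with dilation 1 and no output_padding) is out = (in - 1) * stride + kernel - pad_begin - pad_end; a quick numeric check of that formula:

    def conv_transpose_out_dim(in_dim, stride, kernel, pad_begin, pad_end):
        # ONNX ConvTranspose output size (dilation = 1, output_padding = 0).
        return (in_dim - 1) * stride + kernel - pad_begin - pad_end

    # e.g. upsampling a depth of 4 by stride 2 with a 3x3x3 kernel and pad 1:
    assert conv_transpose_out_dim(4, 2, 3, 1, 1) == 7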
Example No. 11
    def visitId(self, ctx: SeeDotParser.IdContext):
        name = ctx.Id().getText()
        return AST.ID(name)
Example No. 12
    def BroadcastGradientArgs(graph: Graph.Graph, curNode: Graph.Node,
                              dictNodeNameToOutVarStr: dict,
                              extraNodeInfoDict: dict):
        return (None, AST.ID("temp"))  # TODO
Example No. 13
    def StopGradient(graph: Graph.Graph, curNode: Graph.Node,
                     dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 1)
        return (None, AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]))