예제 #1
0
 def ApplyGradientDescent(graph: Graph.Graph, curNode: Graph.Node,
                          dictNodeNameToOutVarStr: dict,
                          extraNodeInfoDict: dict):
     """Translate a TF ApplyGradientDescent node into a SeeDot update.

     The node has exactly three inputs: the variable, the learning rate and
     the delta. The emitted expression is ``var + -(lr .* delta)``.

     Returns:
         A pair ``(AST.ID of the variable input, update expression)``; the
         same AST.ID object is also the left operand of the outer '+'.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 3)
     varAST, alphaAST, deltaAST = (AST.ID(dictNodeNameToOutVarStr[ref])
                                   for ref in inputsRef)

     op = TFNodesAST.getOperatorsIdx
     scaledDelta = AST.BOp(alphaAST, op('.*'), deltaAST)
     negatedStep = AST.UOp(op('-'), scaledDelta)
     updatedAST = AST.BOp(varAST, op('+'), negatedStep)
     return (varAST, updatedAST)
예제 #2
0
    def SoftmaxCrossEntropyWithLogits(graph: Graph.Graph, curNode: Graph.Node,
                                      dictNodeNameToOutVarStr: dict,
                                      extraNodeInfoDict: dict):
        """Translate SoftmaxCrossEntropyWithLogits into nested SeeDot Lets.

        Input 0 holds the logits and input 1 the (one-hot) true distribution
        q(x). With p(x) = softmax(logits), the cross-entropy is
        sum_x{-q(x) * log(p(x))}. The generated program is::

            let temp_softmax = softmax(logits) in
            let temp_1 = -(reduce(labels .* log(temp_softmax), 1, +)) in
            temp_1

        Returns:
            ``(None, outer AST.Let)``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 2)
        logitsInpt = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
        labelsInpt = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
        op = TFNodesAST.getOperatorsIdx

        # TODO : softmax or implement here ?
        # The outer Let is created with no body; the body is attached below.
        softmaxLet = AST.Let(AST.ID('temp_softmax'),
                             AST.Func(op('softmax'), logitsInpt), None)

        # Sum labels .* log(softmax) over dimension 1, then negate.
        tempName = AST.ID('temp_1')
        logProbs = AST.Func(op('log'), AST.ID('temp_softmax'))
        weighted = AST.BOp(labelsInpt, op('.*'), logProbs)
        summed = AST.Reduce(weighted, 1, op('+'))
        crossEntropy = AST.UOp(op('-'), summed)
        softmaxLet.expr = AST.Let(tempName, crossEntropy, AST.ID('temp_1'))
        return (None, softmaxLet)
예제 #3
0
    def Conv2D(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Translate a TF Conv2D node into a SeeDot '#' BOp with options.

        Only the "padding" and "strides" attributes are forwarded:
        - padding is encoded as 0 for SAME, 1 for VALID and -1 for any other
          value;
        - strides are passed through as the attribute's integer list.

        Returns:
            ``(None, AST.BOp(input, '#', filter, options))``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 2)

        # TODO : Parse other options and make sure backend is consuming those
        # Other options left to parse include T, data_format, dilations
        attrMap = curNode.getAttrMapRef()
        paddingStr = attrMap["\"padding\""].getS()
        paddingCode = (0 if paddingStr == "\"SAME\"" else
                       1 if paddingStr == "\"VALID\"" else
                       -1)
        options = {"padding": paddingCode,
                   "strides": attrMap["\"strides\""].getList().getILi()}

        inputAST = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
        filterAST = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
        convOp = TFNodesAST.getOperatorsIdx('#')
        return (None, AST.BOp(inputAST, convOp, filterAST, options))
예제 #4
0
 def BiasAdd(graph: Graph.Graph, curNode: Graph.Node,
             dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Translate a TF BiasAdd node into a SeeDot '+' BOp.

     Returns:
         ``(None, AST.BOp(value, '+', bias))``.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 2)
     valueAST, biasAST = [AST.ID(dictNodeNameToOutVarStr[ref])
                          for ref in inputsRef]
     plusOp = TFNodesAST.getOperatorsIdx('+')
     return (None, AST.BOp(valueAST, plusOp, biasAST))
예제 #5
0
 def FloorDiv(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Translate a TF FloorDiv node as ``floor(x ./ y)``.

     Returns:
         ``(None, AST.Func('floor', AST.BOp(x, './', y)))``.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 2)
     getOp = TFNodesAST.getOperatorsIdx
     numerAST = AST.ID(dictNodeNameToOutVarStr[inputsRef[0]])
     denomAST = AST.ID(dictNodeNameToOutVarStr[inputsRef[1]])
     quotientAST = AST.BOp(numerAST, getOp('./'), denomAST)
     return (None, AST.Func(getOp('floor'), quotientAST))
예제 #6
0
    def Add(node, value_info, node_name_to_out_var_dict,
            innermost_let_ast_node, out_var_count, mtdAST):
        """Translate an ONNX Add node into a chain of SeeDot Let bindings.

        Steps:
        1. Each of the two inputs is rebound through
           ``get_reshaped_input_ast``.
        2. The two rebound names are combined with SeeDot '+'.
        3. The sum is rebound through ``get_reshaped_output_ast``.

        Every binding takes a fresh name from ``get_new_var_name`` and
        increments ``out_var_count``.

        Side effect:
            ``node_name_to_out_var_dict[node.outputs[0]]`` is set to the name
            of the final (reshaped output) binding.

        Returns:
            ``(new innermost Let AST node, updated out_var_count)``.
        """
        node = OnnxNode(node)
        if (DEBUG):
            print(node)
        inputsRef = node.inputs
        assert (len(inputsRef) == 2)

        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1

        reshaped_input_name1 = get_new_var_name(out_var_count)
        reshaped_input1 = get_reshaped_input_ast(inputsRef[1], value_info,
                                                 node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input1, reshaped_input_name1,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name),
                                    getOperatorsIdx('+'),
                                    AST.ID(reshaped_input_name1))
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        reshaped_output_name = get_new_var_name(out_var_count)
        onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info,
                                                  output_name)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, onnx_output_ast, reshaped_output_name,
            mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[node.outputs[0]] = reshaped_output_name

        if (DEBUG):
            # Only shapes that are actually available here are printed; the
            # previous version referenced shape variables that were never
            # defined, so enabling DEBUG raised NameError.
            # Assumes value_info[name][1] is the ONNX shape of `name`.
            print(node.outputs[0])
            print(value_info[inputsRef[0]][1], value_info[inputsRef[1]][1],
                  '->', value_info[node.outputs[0]][1])

        return (innermost_let_ast_node, out_var_count)
예제 #7
0
 def LogSoftmax(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Translate a TF LogSoftmax node as ``x + -(log(reduce(exp(x), -1, +)))``.

     Two distinct AST.ID nodes are created for the single input: one inside
     exp() and one as the left operand of the outer '+'.

     NOTE(review): no max-subtraction is applied before exp(), so large
     logits may overflow in the backend. Also confirm that the reduced
     tensor broadcasts back against x as intended.

     Returns:
         ``(None, log-softmax expression)``.
     """
     inputsRef = curNode.getInputsRef()
     assert (len(inputsRef) == 1)
     inputName = dictNodeNameToOutVarStr[inputsRef[0]]
     opIdx = TFNodesAST.getOperatorsIdx

     sumExpAST = AST.Reduce(AST.Func(opIdx('exp'), AST.ID(inputName)),
                            AST.Int(-1), opIdx('+'))
     logSumExpAST = AST.Func(opIdx('log'), sumExpAST)
     negLogSumExpAST = AST.UOp(opIdx('-'), logSumExpAST)
     return (None, AST.BOp(AST.ID(inputName), opIdx('+'), negLogSumExpAST))
예제 #8
0
    def MatMul(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Translate a TF MatMul node into a SeeDot '*' BOp.

        Each operand is wrapped in AST.Transp when the corresponding
        "transpose_a" / "transpose_b" attribute is present and its boolean
        value is true. A missing attribute means no transpose.

        Returns:
            ``(None, AST.BOp(lhs, '*', rhs))``.
        """
        inputsRef = curNode.getInputsRef()
        assert (len(inputsRef) == 2)
        attrMapRef = curNode.getAttrMapRef()

        def operandAST(inputName, transposeAttr):
            # Build the operand ID, transposing it only if the attribute says so.
            operand = AST.ID(dictNodeNameToOutVarStr[inputName])
            if transposeAttr in attrMapRef and attrMapRef[transposeAttr].getB():
                return AST.Transp(operand)
            return operand

        lhsAST = operandAST(inputsRef[0], "\"transpose_a\"")
        rhsAST = operandAST(inputsRef[1], "\"transpose_b\"")
        return (None, AST.BOp(lhsAST, TFNodesAST.getOperatorsIdx('*'),
                              rhsAST))
예제 #9
0
    def conv3dtranspose(node, value_info, node_name_to_out_var_dict,
                        innermost_let_ast_node, out_var_count, mtdAST):
        """Translate an ONNX 3-D ConvTranspose node into SeeDot Let bindings.

        Steps:
        1. The input is rebound via ``get_reshaped_input_ast`` and the filter
           via ``get_reshaped_filter_ast``.
        2. They are combined with the SeeDot '#T' operator. An options dict
           carries the filter sizes, per-side zero padding, strides and
           output image sizes.
        3. If a third (bias) input is present, it is rebound via
           ``get_reshaped_bias_ast`` and added with '+'.

        Each binding takes a fresh name from ``get_new_var_name`` and
        increments ``out_var_count``.

        Returns:
            ``(new innermost Let AST node, updated out_var_count, name of the
            variable holding the final result)``.

        Raises:
            AssertionError: if the node does not have 2 or 3 inputs, strides
                are not 3-D, ``kernel_shape`` disagrees with the filter shape,
                or the input and filter channel counts differ.
        """
        inputsRef = node.inputs
        # A third input, when present, is a bias to be added to the result.
        # Check arity before indexing inputsRef so a malformed node fails on
        # this assertion instead of an IndexError.
        assert (len(inputsRef) == 2 or len(inputsRef) == 3)

        inputShape = value_info[inputsRef[0]][1]
        filterShape = value_info[inputsRef[1]][1]
        outputShape = value_info[node.outputs[0]][1]
        stridesUsed = node.attrs['strides']
        assert (len(stridesUsed) == 3)
        assert (filterShape[2:] == tuple(node.attrs['kernel_shape']))

        # ONNX orders `pads` as all begin values followed by all end values:
        # [D_begin, H_begin, W_begin, D_end, H_end, W_end]. Unpacking it as
        # interleaved (begin, end) pairs would assign pads to the wrong axes.
        [zPadDLeft, zPadHLeft, zPadWLeft,
         zPadDRight, zPadHRight, zPadWRight] = node.attrs['pads']

        options = {
            AST.PaddingKeysDict.FD: filterShape[2],
            AST.PaddingKeysDict.FH: filterShape[3],
            AST.PaddingKeysDict.FW: filterShape[4],
            AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
            AST.PaddingKeysDict.zPadDRight: zPadDRight,
            AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
            AST.PaddingKeysDict.zPadHRight: zPadHRight,
            AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
            AST.PaddingKeysDict.zPadWRight: zPadWRight,
            AST.PaddingKeysDict.strideD: stridesUsed[0],
            AST.PaddingKeysDict.strideH: stridesUsed[1],
            AST.PaddingKeysDict.strideW: stridesUsed[2],
            AST.PaddingKeysDict.ConvDim: 3,
            AST.PaddingKeysDict.outputImgD: outputShape[2],
            AST.PaddingKeysDict.outputImgH: outputShape[3],
            AST.PaddingKeysDict.outputImgW: outputShape[4],
        }

        # Input channels (dim 1) must match the filter's leading dim.
        assert (inputShape[1] == filterShape[0])
        # For Input:
        # [N, CI, D, H, W] is the Onnx order it should be changed to
        # [N, D, H, W, CI] order
        reshaped_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, reshaped_input_name,
            mtdAST)
        out_var_count += 1

        # For filter:
        # [CI, CO, FD, FH, FW] is the Onnx order it should be changed to
        # [FD, FH, FW, CI1, CO] order
        reshaped_filter_name = get_new_var_name(out_var_count)
        reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                                  node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_filter, reshaped_filter_name,
            mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name),
                                    getOperatorsIdx('#T'),
                                    AST.ID(reshaped_filter_name), options)
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        # If there is bias to be added then reshape and add it
        if (len(inputsRef) == 3):
            reshaped_bias_name = get_new_var_name(out_var_count)
            reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info,
                                                  node_name_to_out_var_dict, 3)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, reshaped_bias, reshaped_bias_name,
                mtdAST)
            out_var_count += 1

            # NOTE(review): the conv options dict is also attached to the '+'
            # BOp here; confirm the backend expects/ignores it.
            seedot_output_ast = AST.BOp(AST.ID(output_name),
                                        getOperatorsIdx('+'),
                                        AST.ID(reshaped_bias_name), options)
            output_name = get_new_var_name(out_var_count)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
            out_var_count += 1

        return (innermost_let_ast_node, out_var_count, output_name)