Beispiel #1
0
    def Gemm(node, value_info, node_name_to_out_var_dict,
             innermost_let_ast_node, out_var_count, mtdAST):
        """Translate a Gemm node into a SeeDot AST expression.

        The node is wrapped in ``OnnxNode`` and must have exactly three
        inputs. The first two inputs are multiplied (each optionally wrapped
        in ``AST.Transp`` when ``transA`` / ``transB`` is set and truthy) and
        the product is combined with the third input via
        ``SeeDotParser.ADDCIR``.

        Args:
            node: Raw node; wrapped with ``OnnxNode`` before use.
            value_info: Not used by this handler.
            node_name_to_out_var_dict: Maps node input/output names to the
                variable names holding their values. Updated in place with
                the entry for this node's first output.
            innermost_let_ast_node: Current innermost node of the program,
                forwarded to ``update_program_with_new_node``.
            out_var_count: Counter used to derive the new variable name.
            mtdAST: Forwarded to ``update_program_with_new_node``.

        Returns:
            A tuple ``(innermost_let_ast_node, out_var_count)`` holding the
            node returned by ``update_program_with_new_node`` and the counter
            incremented by one.
        """
        node = OnnxNode(node)
        if DEBUG:
            print(node)

        operand_names = node.inputs
        assert len(operand_names) == 3

        lhs = AST.ID(node_name_to_out_var_dict[operand_names[0]])
        rhs = AST.ID(node_name_to_out_var_dict[operand_names[1]])

        # NOTE(review): only transA/transB are read from the attributes; any
        # other attribute (e.g. alpha/beta scaling) is not consulted here --
        # confirm that this is intended.
        attrs = node.attrs
        if 'transA' in attrs and attrs['transA']:
            lhs = AST.Transp(lhs)
        if 'transB' in attrs and attrs['transB']:
            rhs = AST.Transp(rhs)

        # (input1 MUL input2) ADDCIR input3
        product = AST.Bop1(lhs, SeeDotParser.MUL, rhs)
        addend = AST.ID(node_name_to_out_var_dict[operand_names[2]])
        gemm_expr = AST.Bop1(product, SeeDotParser.ADDCIR, addend)

        result_var = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, gemm_expr, result_var, mtdAST)

        # Record where this node's first output lives for later consumers.
        node_name_to_out_var_dict[node.outputs[0]] = result_var

        return (innermost_let_ast_node, out_var_count + 1)
Beispiel #2
0
    def MatMul(graph: Graph.Graph, curNode: Graph.Node,
               dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
        """Build the AST for a two-input matrix multiplication node.

        Each input is looked up in ``dictNodeNameToOutVarStr`` and wrapped in
        ``AST.ID``. If the node's attribute map contains the key
        ``"transpose_a"`` / ``"transpose_b"`` (including the literal double
        quotes) and its ``getB()`` value is truthy, the corresponding operand
        is wrapped in ``AST.Transp``.

        Args:
            graph: Not used by this handler.
            curNode: Node providing ``getInputsRef()`` and
                ``getAttrMapRef()``.
            dictNodeNameToOutVarStr: Maps input names to variable names.
            extraNodeInfoDict: Not used by this handler.

        Returns:
            A tuple ``(None, ast)`` where ``ast`` is an ``AST.BOp`` of the two
            operands using the operator index for ``'*'``.
        """
        operandNames = curNode.getInputsRef()
        assert len(operandNames) == 2

        lhsVarStr = dictNodeNameToOutVarStr[operandNames[0]]
        rhsVarStr = dictNodeNameToOutVarStr[operandNames[1]]
        lhsAST, rhsAST = AST.ID(lhsVarStr), AST.ID(rhsVarStr)

        # Attribute keys carry embedded double quotes; absent keys mean
        # "no transpose".
        attrs = curNode.getAttrMapRef()
        wantTransposeA = (attrs["\"transpose_a\""].getB()
                          if "\"transpose_a\"" in attrs else False)
        wantTransposeB = (attrs["\"transpose_b\""].getB()
                          if "\"transpose_b\"" in attrs else False)

        if wantTransposeA:
            lhsAST = AST.Transp(lhsAST)
        if wantTransposeB:
            rhsAST = AST.Transp(rhsAST)

        productAST = AST.BOp(lhsAST, TFNodesAST.getOperatorsIdx('*'), rhsAST)
        return (None, productAST)
Beispiel #3
0
 def visitTransp(self, ctx: SeeDotParser.TranspContext):
     """Visit the context's sub-expression and wrap it in ``AST.Transp``."""
     # Visit the operand first, then wrap the resulting node.
     operand = self.visit(ctx.expr())
     return AST.Transp(operand)