Example #1
0
    def SoftmaxCrossEntropyWithLogits(graph: Graph.Graph, curNode: Graph.Node,
                                      dictNodeNameToOutVarStr: dict,
                                      extraNodeInfoDict: dict):
        """Lower a SoftmaxCrossEntropyWithLogits node into a two-binding Let chain.

        Input 0 is the logits and input 1 the one-hot true distribution q(x).
        Softmax of the logits gives p(x); the loss is
        cross-entropy = sum_x { -q(x) * log(p(x)) }, computed as the negation
        of a '+'-reduction over dimension 1 of ``labels .* log(softmax)``.

        ``graph`` and ``extraNodeInfoDict`` are not used here.

        Returns:
            ``(None, ast)`` where ``ast`` is
            ``Let(temp_softmax, softmax(logits), Let(temp_1, -reduce(...), temp_1))``.
        """
        inputs = curNode.getInputsRef()
        assert (len(inputs) == 2)
        logits = AST.ID(dictNodeNameToOutVarStr[inputs[0]])
        labels = AST.ID(dictNodeNameToOutVarStr[inputs[1]])

        opIdx = TFNodesAST.getOperatorsIdx

        # First binding: temp_softmax = softmax(logits); body filled in below.
        # TODO: use the softmax builtin, or implement softmax inline here?
        softmaxId = AST.ID('temp_softmax')
        softmaxCall = AST.Func(opIdx('softmax'), logits)
        letSoftmax = AST.Let(softmaxId, softmaxCall, None)

        # Second binding: temp_1 = -(sum over dim 1 of labels .* log(temp_softmax)).
        # The original notes this reduces along columns to get a row vector.
        lossId = AST.ID('temp_1')
        negOp = opIdx('-')
        mulOp = opIdx('.*')
        logCall = AST.Func(opIdx('log'), AST.ID('temp_softmax'))
        product = AST.BOp(labels, mulOp, logCall)
        rowSum = AST.Reduce(product, 1, opIdx('+'))
        negated = AST.UOp(negOp, rowSum)
        letSoftmax.expr = AST.Let(lossId, negated, AST.ID('temp_1'))

        return (None, letSoftmax)
Example #2
0
def _append_input_let(program, innermost_let_ast_node, name, decl_ast):
    """Bind ``decl_ast`` to ``name`` in a new Let node and attach it to the chain.

    The new node is ``AST.Let(name, decl_ast, AST.ID(name))``. If ``program``
    already exists, the new node becomes the ``expr`` of the current innermost
    Let. Otherwise it starts the chain (``depth = 0``) and becomes ``program``.

    Returns:
        ``(program, new_innermost_let_ast_node)``.
    """
    new_let = AST.Let(name, decl_ast, AST.ID(name))
    if program:
        assert isinstance(innermost_let_ast_node, AST.Let)
        # Splice the new binding in as the body of the previous innermost Let.
        innermost_let_ast_node.expr = new_let
    else:
        new_let.depth = 0
        program = new_let
    return program, new_let


def process_input_variables(program, innermost_let_ast_node,
                            node_name_to_out_var_dict, out_var_count, mtdAST,
                            graph_def, value_info):
    """Emit Let bindings for the graph's first input and for its initializers.

    Each emitted value is bound under its own ``node.name``. That name is also
    recorded in ``node_name_to_out_var_dict`` (mapped to itself). Initializers
    for which ``ONNXNodesAST.Input`` returns ``None`` are skipped.

    Args:
        program: Root Let node built so far, or a falsy value if none yet.
        innermost_let_ast_node: Current innermost Let of ``program``.
        node_name_to_out_var_dict: Name-to-variable mapping; updated in place.
        out_var_count: Passed through unchanged.
        mtdAST: Unused. Metadata annotation of these nodes was disabled in
            the original code; the parameter is kept for interface
            compatibility.
        graph_def: Graph providing ``input`` and ``initializer`` sequences.
        value_info: Forwarded to ``ONNXNodesAST.Input``.

    Returns:
        ``(program, innermost_let_ast_node, out_var_count)``.

    NOTE(review): only ``graph_def.input[0]`` is processed. Any further graph
    inputs are ignored, and an empty ``input`` list raises IndexError. Confirm
    this is intended.
    """
    node = graph_def.input[0]
    cur_ast = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict)
    program, innermost_let_ast_node = _append_input_let(
        program, innermost_let_ast_node, node.name, cur_ast)
    node_name_to_out_var_dict[node.name] = node.name

    for node in graph_def.initializer:
        if DEBUG:
            print("Node information")
            print(node)

        # Initializers also pass the node itself as a 4th argument (presumably
        # the source of its constant value -- confirm against ONNXNodesAST.Input).
        cur_ast = ONNXNodesAST.Input(node, value_info,
                                     node_name_to_out_var_dict, node)
        if cur_ast is None:
            continue

        program, innermost_let_ast_node = _append_input_let(
            program, innermost_let_ast_node, node.name, cur_ast)
        node_name_to_out_var_dict[node.name] = node.name

    return (program, innermost_let_ast_node, out_var_count)
Example #3
0
File: astBuilder.py  Project: shas19/EdgeML
    def visitLet(self, ctx: SeeDotParser.LetContext):
        """Build an ``AST.Let`` from a Let parse-tree context.

        The bound name is read from the LHS's ``Id`` token. ``expr(0)`` is
        visited as the declaration and ``expr(1)`` as the body, in that order.
        When the LHS is a ``LeftSpliceContext``, it is visited last, and its
        result is passed as the fourth argument to ``AST.Let``.
        """
        target = ctx.lhs()
        name = target.Id().getText()
        decl, expr = [self.visit(ctx.expr(i)) for i in (0, 1)]

        if not isinstance(target, SeeDotParser.LeftSpliceContext):
            return AST.Let(name, decl, expr)

        # Left splicing: the splice node carries extra information for the Let.
        return AST.Let(name, decl, expr, self.visit(target))
Example #4
0
def update_program_with_new_node(innermost_let_ast_node, new_node,
                                 new_node_name, mtdAST):
    """Append ``AST.Let(new_node_name, new_node, AST.ID(new_node_name))`` to a chain.

    The new Let is stored as ``innermost_let_ast_node.expr``, and is returned
    so the caller can treat it as the new innermost node.

    ``mtdAST`` is unused; metadata annotation was disabled in the original.
    NOTE(review): this does not update any name-to-variable mapping (that
    line was commented out). Callers presumably record it themselves; verify.
    """
    appended_let = AST.Let(new_node_name, new_node, AST.ID(new_node_name))
    innermost_let_ast_node.expr = appended_let
    return appended_let
Example #5
0
def generateIRCode(graph, extraInfoDict):
    """Translate every node of ``graph`` into one right-nested chain of Let nodes.

    Nodes are processed in ``graph.getAllNodesRef()`` order. Every input of a
    node must already have been emitted, which topological order guarantees.

    For each node, ``generateASTForNode`` yields ``(assignedVarAST, ast)``:
    - If ``ast`` is ``None``, the node is mapped to ``None``. It gets no
      binding and consumes no counter value.
    - Otherwise its result is recorded under ``"J<counter>"``, a Let node is
      appended, and ``MtdAST`` attaches the op/node names with their first and
      last characters stripped (presumably surrounding quotes; confirm).

    Returns:
        ``(program, dictNodeNameToOutVarStr)``. ``program`` is the root Let,
        or ``None`` if no node produced an AST.

    NOTE(review): the first emitted node binds a fresh ``AST.ID("J<n>")``.
    Later nodes bind ``resultAst``, which is ``assignedVarAST`` when that is
    truthy. The mapping always records ``"J<n>"``. Confirm this asymmetry is
    intended.
    """
    program = None
    tail = None
    nameToVar = {}
    counter = 0
    prefix = "J"
    metadataVisitor = MtdAST()

    for node in graph.getAllNodesRef():
        for inputName in node.getInputsRef():
            # Topological order means every input was emitted earlier.
            assert (inputName in nameToVar)

        assignedVarAST, nodeAst = generateASTForNode(graph, node, nameToVar,
                                                     extraInfoDict)
        metadata = {
            AST.ASTNode.mtdKeyTFOpName: node.getOp()[1:-1],
            AST.ASTNode.mtdKeyTFNodeName: node.getName()[1:-1],
        }

        if nodeAst is None:
            nameToVar[node.getName()] = None
            continue

        varName = prefix + str(counter)
        if assignedVarAST:
            resultAst = assignedVarAST
        else:
            resultAst = AST.ID(varName)

        if program:
            assert (type(tail) is AST.Let)
            letNode = AST.Let(resultAst, nodeAst, resultAst)
            metadataVisitor.visit(letNode, metadata)
            tail.expr = letNode
        else:
            letNode = AST.Let(AST.ID(varName), nodeAst, resultAst)
            metadataVisitor.visit(letNode, metadata)
            letNode.depth = 0
            program = letNode
        tail = letNode

        nameToVar[node.getName()] = varName
        counter += 1

    return (program, nameToVar)
Example #6
0
 def visitLet(self, ctx: SeeDotParser.LetContext):
     """Build an ``AST.Let`` from a Let parse-tree context.

     The bound name comes from the context's ``Id`` token. ``expr(0)`` is
     visited as the declaration and ``expr(1)`` as the body, in that order.
     """
     boundName = ctx.Id().getText()
     decl, body = [self.visit(ctx.expr(i)) for i in (0, 1)]
     return AST.Let(boundName, decl, body)