def SoftmaxCrossEntropyWithLogits(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a softmax-cross-entropy node into two nested Let bindings.

    Input 0 holds the logits; input 1 holds the (one-hot) true distribution q(x).
    Softmax of the logits gives p(x), and the cross-entropy is
    sum_x -q(x) * log(p(x)). The emitted AST is:

        let temp_softmax = softmax(logits) in
        let temp_1 = -(reduce[axis=1, '+'](labels .* log(temp_softmax))) in
        temp_1

    ``graph`` and ``extraNodeInfoDict`` are accepted for interface uniformity
    but are not read here.

    Returns:
        ``(None, outer_let)``: no pre-assigned output variable, plus the outer
        Let node, whose ``expr`` holds the inner Let.
    """
    input_names = curNode.getInputsRef()
    assert (len(input_names) == 2)
    logits_id = AST.ID(dictNodeNameToOutVarStr[input_names[0]])
    labels_id = AST.ID(dictNodeNameToOutVarStr[input_names[1]])

    op_idx = TFNodesAST.getOperatorsIdx

    # TODO : softmax or implement here ?
    # Outer binding: temp_softmax = softmax(logits). Its body is filled in below.
    outer_let = AST.Let(AST.ID('temp_softmax'), AST.Func(op_idx('softmax'), logits_id), None)

    # Inner binding: temp_1 = -(sum over axis 1 of labels .* log(temp_softmax)).
    # Original intent: "reduce along column to get row-vector". Presumably the
    # inputs are rank-2 (batch, classes) -- TODO confirm.
    loss_name = AST.ID('temp_1')
    negate = op_idx('-')
    weighted_log = AST.BOp(labels_id, op_idx('.*'), AST.Func(op_idx('log'), AST.ID('temp_softmax')))
    row_sum = AST.Reduce(weighted_log, 1, op_idx('+'))
    outer_let.expr = AST.Let(loss_name, AST.UOp(negate, row_sum), AST.ID('temp_1'))

    return (None, outer_let)
def _append_input_let(program, innermost_let_ast_node, var_name, decl_ast):
    """Bind ``var_name = decl_ast`` as a new Let at the tail of the program chain.

    If ``program`` is falsy, the new Let becomes the root of the chain and gets
    ``depth = 0``. Otherwise it is attached as ``innermost_let_ast_node.expr``.
    In both cases the Let's body is ``AST.ID(var_name)``.

    Returns:
        ``(program, innermost_let_ast_node)`` after the append.
    """
    body_ast = AST.ID(var_name)
    if program:
        assert (type(innermost_let_ast_node) is AST.Let)
        new_let = AST.Let(var_name, decl_ast, body_ast)
        # Hang the new Let off the previous innermost Let and make it the new tail.
        innermost_let_ast_node.expr = new_let
        return program, new_let
    new_let = AST.Let(var_name, decl_ast, body_ast)
    new_let.depth = 0
    return new_let, new_let


def process_input_variables(program, innermost_let_ast_node, node_name_to_out_var_dict, out_var_count, mtdAST, graph_def, value_info):
    """Append Let bindings for the ONNX graph's first input and its initializers.

    Each bound name is the ONNX node name itself, and
    ``node_name_to_out_var_dict`` maps that name to itself.

    Args:
        program: Root Let of the program so far; falsy if nothing has been emitted yet.
        innermost_let_ast_node: Current tail of the Let chain.
        node_name_to_out_var_dict: Updated in place with every bound name.
        out_var_count: Passed through unchanged.
        mtdAST: Currently unused. Metadata annotation via ``mtdAST.visit`` was
            disabled (commented out) in the previous version of this function.
        graph_def: ONNX graph. Only ``graph_def.input[0]`` and
            ``graph_def.initializer`` are read.
        value_info: Forwarded to ``ONNXNodesAST.Input``.

    Returns:
        ``(program, innermost_let_ast_node, out_var_count)``.
    """
    # Only the first declared graph input is lowered here.
    first_input = graph_def.input[0]
    input_ast = ONNXNodesAST.Input(first_input, value_info, node_name_to_out_var_dict)
    program, innermost_let_ast_node = _append_input_let(program, innermost_let_ast_node, first_input.name, input_ast)
    node_name_to_out_var_dict[first_input.name] = first_input.name

    for node in graph_def.initializer:
        if (DEBUG):
            print("Node information")
            print(node)
        init_ast = ONNXNodesAST.Input(node, value_info, node_name_to_out_var_dict, node)
        # Input() may return None for an initializer; such nodes get no binding.
        if init_ast is None:
            continue
        program, innermost_let_ast_node = _append_input_let(program, innermost_let_ast_node, node.name, init_ast)
        node_name_to_out_var_dict[node.name] = node.name

    return (program, innermost_let_ast_node, out_var_count)
def visitLet(self, ctx: SeeDotParser.LetContext):
    """Build an ``AST.Let`` from a parsed let-expression.

    The bound name comes from the ``lhs`` context. ``expr(0)`` is the bound
    value and ``expr(1)`` is the body. When the lhs is a ``LeftSpliceContext``,
    the lhs is visited as well, and the result is passed to ``AST.Let`` as a
    fourth argument.
    """
    target = ctx.lhs()
    bound_name = target.Id().getText()
    value_ast = self.visit(ctx.expr(0))
    body_ast = self.visit(ctx.expr(1))

    if not isinstance(target, SeeDotParser.LeftSpliceContext):
        return AST.Let(bound_name, value_ast, body_ast)

    # Left splicing: the splice node is visited after both sub-expressions.
    splice_ast = self.visit(target)
    return AST.Let(bound_name, value_ast, body_ast, splice_ast)
def update_program_with_new_node(innermost_let_ast_node, new_node, new_node_name, mtdAST):
    """Append ``new_node_name = new_node`` to the tail of an existing Let chain.

    A new ``AST.Let(new_node_name, new_node, AST.ID(new_node_name))`` is stored
    as ``innermost_let_ast_node.expr``.

    ``mtdAST`` is accepted but not used (metadata visiting is currently disabled).
    The caller is responsible for updating any name-to-variable mapping.

    Returns:
        The newly created Let, which is the new innermost node of the chain.
    """
    appended_let = AST.Let(new_node_name, new_node, AST.ID(new_node_name))
    innermost_let_ast_node.expr = appended_let
    return appended_let
def generateIRCode(graph, extraInfoDict):
    """Lower every node of ``graph`` into a chain of nested ``AST.Let`` bindings.

    Nodes are visited in ``graph.getAllNodesRef()`` order, and every input of a
    node must already have been processed (asserted). Each node that yields an
    AST is bound to a fresh variable ``J<k>``. ``k`` counts only nodes that
    produced an AST. Each Let is annotated with op/node-name metadata via
    ``MtdAST``.

    Returns:
        ``(program, dictNodeNameToOutVarStr)``: ``program`` is the outermost Let
        (None if no node produced an AST). The dict maps each node name to its
        ``J<k>`` variable string, or to None for nodes that produced no AST.
    """
    program = None
    tail_let = None
    dictNodeNameToOutVarStr = {}
    var_counter = 0
    var_prefix = "J"
    mtdAST = MtdAST()

    for tf_node in graph.getAllNodesRef():
        # The graph is expected to be topologically sorted, so inputs are already mapped.
        for input_name in tf_node.getInputsRef():
            assert (input_name in dictNodeNameToOutVarStr)

        assigned_var_ast, node_ast = generateASTForNode(graph, tf_node, dictNodeNameToOutVarStr, extraInfoDict)

        # [1:-1] drops the first and last characters of the op/node name
        # (presumably surrounding quotes -- confirm against Graph.Node).
        node_metadata = {
            AST.ASTNode.mtdKeyTFOpName: tf_node.getOp()[1:-1],
            AST.ASTNode.mtdKeyTFNodeName: tf_node.getName()[1:-1],
        }

        if node_ast is None:
            dictNodeNameToOutVarStr[tf_node.getName()] = None
            continue

        var_str = var_prefix + str(var_counter)
        var_ast = assigned_var_ast if assigned_var_ast else AST.ID(var_str)

        if not program:
            # First emitted node becomes the root of the chain.
            # NOTE(review): the root binds a fresh AST.ID(var_str), while later
            # nodes bind var_ast (which may be assigned_var_ast). Confirm this
            # asymmetry is intended.
            tail_let = AST.Let(AST.ID(var_str), node_ast, var_ast)
            mtdAST.visit(tail_let, node_metadata)
            tail_let.depth = 0
            program = tail_let
        else:
            assert (type(tail_let) is AST.Let)
            next_let = AST.Let(var_ast, node_ast, var_ast)
            mtdAST.visit(next_let, node_metadata)
            tail_let.expr = next_let
            tail_let = next_let

        dictNodeNameToOutVarStr[tf_node.getName()] = var_str
        var_counter += 1

    return (program, dictNodeNameToOutVarStr)
def visitLet(self, ctx: SeeDotParser.LetContext):
    """Build an ``AST.Let`` from a parsed let-expression.

    The bound name is the text of the context's ``Id`` token. ``expr(0)`` is
    visited first as the bound value, then ``expr(1)`` as the body.
    """
    bound_name = ctx.Id().getText()
    value_ast, body_ast = [self.visit(ctx.expr(idx)) for idx in range(2)]
    return AST.Let(bound_name, value_ast, body_ast)