def SoftmaxCrossEntropyWithLogits(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Translate a SoftmaxCrossEntropyWithLogits node into an AST.

    Input 0 of ``curNode`` is the logits and input 1 is the one-hot true
    distribution q(x).  Softmax of the logits gives p(x); the result is the
    cross-entropy ``sum_x(-q(x) * log(p(x)))``.  The generated program is::

        let temp_softmax = softmax(logits) in
        let temp_1 = -(reduce(labels .* log(temp_softmax), 1, +)) in
        temp_1

    ``graph`` and ``extraNodeInfoDict`` are not used by this translator.

    Returns:
        ``(None, ast)`` where ``ast`` is the outer ``AST.Let``.
    """
    operandNames = curNode.getInputsRef()
    assert len(operandNames) == 2
    logitsVar = AST.ID(dictNodeNameToOutVarStr[operandNames[0]])
    labelsVar = AST.ID(dictNodeNameToOutVarStr[operandNames[1]])
    opIdx = TFNodesAST.getOperatorsIdx

    # TODO : softmax or implement here ?
    # The outer let is created with an empty body; the body is attached below.
    outerLet = AST.Let(
        AST.ID('temp_softmax'),
        AST.Func(opIdx('softmax'), logitsVar),
        None)

    resultVar = AST.ID('temp_1')
    # labels .* log(temp_softmax)
    logProbs = AST.Func(opIdx('log'), AST.ID('temp_softmax'))
    weightedLogProbs = AST.BOp(labelsVar, opIdx('.*'), logProbs)
    # reduce along column to get row-vector
    summed = AST.Reduce(weightedLogProbs, 1, opIdx('+'))
    crossEntropy = AST.UOp(opIdx('-'), summed)

    # A second, separate ID node is used for the body of the inner let.
    outerLet.expr = AST.Let(resultVar, crossEntropy, AST.ID('temp_1'))
    return (None, outerLet)
def LogSoftmax(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Translate a LogSoftmax node into an AST.

    For the single input ``x`` of ``curNode`` the generated expression is::

        x + (-(log(reduce(exp(x), Int(-1), +))))

    ``graph`` and ``extraNodeInfoDict`` are not used by this translator.

    Returns:
        ``(None, ast)`` where ``ast`` is the top-level ``AST.BOp``.
    """
    operandNames = curNode.getInputsRef()
    assert len(operandNames) == 1
    opIdx = TFNodesAST.getOperatorsIdx
    inputVarName = dictNodeNameToOutVarStr[operandNames[0]]

    expOfInput = AST.Func(opIdx('exp'), AST.ID(inputVarName))
    sumOfExp = AST.Reduce(expOfInput, AST.Int(-1), opIdx('+'))

    # The left operand gets its own ID node instead of sharing the one that
    # already sits inside expOfInput.
    leftOperand = AST.ID(inputVarName)
    negLogSumExp = AST.UOp(opIdx('-'), AST.Func(opIdx('log'), sumOfExp))
    return (None, AST.BOp(leftOperand, opIdx('+'), negLogSumExp))
def Shape(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Translate a Shape node into ``shape(<input>)``.

    ``curNode`` must have exactly one input; its output-variable name is
    looked up in ``dictNodeNameToOutVarStr``.  ``graph`` and
    ``extraNodeInfoDict`` are not used by this translator.

    Returns:
        ``(None, ast)`` where ``ast`` is the ``AST.Func`` node.
    """
    operandNames = curNode.getInputsRef()
    assert len(operandNames) == 1
    shapeOp = TFNodesAST.getOperatorsIdx('shape')
    inputVar = AST.ID(dictNodeNameToOutVarStr[operandNames[0]])
    return (None, AST.Func(shapeOp, inputVar))
def FloorDiv(graph: Graph.Graph, curNode: Graph.Node, dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Translate a FloorDiv node into ``floor(<input0> ./ <input1>)``.

    ``curNode`` must have exactly two inputs; their output-variable names are
    looked up in ``dictNodeNameToOutVarStr``.  ``graph`` and
    ``extraNodeInfoDict`` are not used by this translator.

    Returns:
        ``(None, ast)`` where ``ast`` is the ``AST.Func`` node for ``floor``.
    """
    operandNames = curNode.getInputsRef()
    assert len(operandNames) == 2
    opIdx = TFNodesAST.getOperatorsIdx

    numerator = AST.ID(dictNodeNameToOutVarStr[operandNames[0]])
    divideOp = opIdx('./')
    denominator = AST.ID(dictNodeNameToOutVarStr[operandNames[1]])
    quotient = AST.BOp(numerator, divideOp, denominator)
    return (None, AST.Func(opIdx('floor'), quotient))
def ArgMax(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST):
    """Append a SeeDot ``argmax`` call for an ONNX ArgMax node to the program.

    The node's first input is looked up in ``node_name_to_out_var_dict``,
    wrapped in ``AST.Func(SeeDotParser.ARGMAX, ...)`` and bound to a fresh
    variable via ``update_program_with_new_node``.  The node's first output
    name is then mapped to that variable in ``node_name_to_out_var_dict``
    (the dict is mutated in place).  ``value_info`` is not used here.

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` with the let node returned
        by ``update_program_with_new_node`` and the counter advanced by one.
    """
    node = OnnxNode(node)
    if DEBUG:
        print(node)

    argmax_ast = AST.Func(
        SeeDotParser.ARGMAX,
        AST.ID(node_name_to_out_var_dict[node.inputs[0]]))

    result_var = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, argmax_ast, result_var, mtdAST)

    node_name_to_out_var_dict[node.outputs[0]] = result_var
    return (innermost_let_ast_node, out_var_count + 1)
def Relu(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node, out_var_count, mtdAST):
    """Append a SeeDot ``relu`` call for an ONNX Relu node to the program.

    When the input's ``value_info`` entry has four or more dimensions, the
    input is first bound to a fresh variable through
    ``get_reshaped_input_ast`` and the relu result is bound to another fresh
    variable through ``get_reshaped_output_ast``.  Otherwise relu is applied
    directly to the existing variable.  Every new binding goes through
    ``update_program_with_new_node`` and consumes one value of
    ``out_var_count``.

    Args:
        node: Raw ONNX node; wrapped in ``OnnxNode`` here.
        value_info: Indexed by tensor name; element ``[1]`` of the input's
            entry is measured with ``len()`` (presumably the ONNX shape --
            confirm against the caller).
        node_name_to_out_var_dict: Maps tensor names to program variable
            names; updated in place with this node's first output.
        innermost_let_ast_node: Current innermost let node of the program.
        out_var_count: Counter used to generate fresh variable names.
        mtdAST: Passed through to ``update_program_with_new_node``.

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` after all bindings.
    """
    node = OnnxNode(node)

    inputsRef = node.inputs
    assert (len(inputsRef) == 1)

    # Bound once so the DEBUG print below can report it as well.
    onnx_input_shape = value_info[inputsRef[0]][1]
    needs_reshape = len(onnx_input_shape) >= 4

    relu_input_name = node_name_to_out_var_dict[inputsRef[0]]
    if needs_reshape:
        relu_input_name = get_new_var_name(out_var_count)
        reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info, node_name_to_out_var_dict)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_input, relu_input_name, mtdAST)
        out_var_count += 1

    seedot_output_ast = AST.Func(SeeDotParser.RELU, AST.ID(relu_input_name))
    output_name = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
    out_var_count += 1

    final_output_name = output_name
    if needs_reshape:
        final_output_name = get_new_var_name(out_var_count)
        onnx_output_ast = get_reshaped_output_ast(node.outputs[0], value_info, output_name)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, onnx_output_ast, final_output_name, mtdAST)
        out_var_count += 1

    node_name_to_out_var_dict[node.outputs[0]] = final_output_name

    if DEBUG:
        print(node.outputs[0])
        # The previous print referenced onnx_input_shape, seedot_input_shape
        # and onnx_output_shape, none of which was defined in this function,
        # so enabling DEBUG raised NameError.  Only values actually held here
        # are reported now.
        print(onnx_input_shape, '->', final_output_name)

    return (innermost_let_ast_node, out_var_count)
def visitFunc(self, ctx: SeeDotParser.FuncContext):
    """Build an ``AST.Func`` node from a ``func`` parse-tree context.

    The operator is the token type of the first child of the
    ``specialFunc`` sub-rule; the argument is the result of visiting the
    context's ``expr`` child.
    """
    func_token_type = ctx.specialFunc().getChild(0).symbol.type
    argument_ast = self.visit(ctx.expr())
    return AST.Func(func_token_type, argument_ast)