def Transpose(node, value_info, node_name_to_out_var_dict,
              innermost_let_ast_node, out_var_count, mtdAST):
    """Lower an ONNX ``Transpose`` node to a SeeDot ``AST.Transpose``.

    The permutation comes from the ``perm`` attribute. When that attribute is
    absent, the ONNX default (reverse all dimensions) is used, with the rank
    taken from the input's inferred shape ``value_info[input][1]``.

    The transposed value is bound to a fresh variable, which is recorded as
    the SeeDot name of the node's output.

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` after the new binding.
    """
    node = OnnxNode(node)
    if DEBUG:
        print(node)
    inputsRef = node.inputs
    assert len(inputsRef) == 1

    perm = node.attrs.get('perm')
    if perm is None:
        # ONNX spec: "By default, reverse the dimensions".
        rank = len(value_info[inputsRef[0]][1])
        perm = list(range(rank - 1, -1, -1))

    seedot_output_ast = AST.Transpose(
        AST.ID(node_name_to_out_var_dict[inputsRef[0]]), perm)
    output_name = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
    out_var_count += 1
    node_name_to_out_var_dict[node.outputs[0]] = output_name
    return (innermost_let_ast_node, out_var_count)
def GlobalAveragePool(node, value_info, node_name_to_out_var_dict,
                      innermost_let_ast_node, out_var_count, mtdAST):
    """Lower an ONNX ``GlobalAveragePool`` node into SeeDot AST nodes.

    Three bindings are appended to the let-chain, each under a fresh name:
      1. the input reshaped to SeeDot layout via ``get_reshaped_input_ast``;
      2. an ``AvgPool`` whose window spans dims 2 and 3 of the ONNX input
         shape, with no padding and unit strides;
      3. the pooled result reshaped back to the node's ONNX output shape.

    The last binding is recorded as the SeeDot name of the node's output.

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` after all bindings.
    """
    node = OnnxNode(node)
    if DEBUG:
        print(node)
    sources = node.inputs
    assert len(sources) == 1
    src = sources[0]

    def emit(ast_node):
        """Bind ``ast_node`` to a fresh variable and return that name."""
        nonlocal innermost_let_ast_node, out_var_count
        var_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, ast_node, var_name, mtdAST)
        out_var_count += 1
        return var_name

    layout_name = emit(
        get_reshaped_input_ast(src, value_info, node_name_to_out_var_dict))

    # Window = full spatial extent (dims 2/3 of the ONNX shape, assumed NCHW).
    onnx_shape = value_info[src][1]
    pool_options = {
        AST.PaddingKeysDict.FH: onnx_shape[2],
        AST.PaddingKeysDict.FW: onnx_shape[3],
        AST.PaddingKeysDict.zPadHLeft: 0,
        AST.PaddingKeysDict.zPadHRight: 0,
        AST.PaddingKeysDict.zPadWLeft: 0,
        AST.PaddingKeysDict.zPadWRight: 0,
        AST.PaddingKeysDict.strideH: 1,
        AST.PaddingKeysDict.strideW: 1,
    }
    pooled_name = emit(AST.Pool(AST.Pool.PoolType.AvgPool,
                                AST.ID(layout_name), pool_options))

    dst = node.outputs[0]
    final_name = emit(get_reshaped_output_ast(dst, value_info, pooled_name))
    node_name_to_out_var_dict[dst] = final_name
    return (innermost_let_ast_node, out_var_count)
def AvgPool(graph: Graph.Graph, curNode: Graph.Node,
            dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower a TensorFlow ``AvgPool`` node to an uninterpreted function call.

    Reads the node's ``strides``, ``ksize`` and ``padding`` attributes and
    emits ``AvgPool(kH, kW, padH, padW, strideH, strideW, input)``. The call
    is an ``AST.UninterpFuncCall`` whose result shape is
    ``extraNodeInfoDict[curNode.getName()][0]``.

    Returns:
        ``(None, call_ast)``.
    """
    inputsRef = curNode.getInputsRef()
    assert len(inputsRef) == 1
    attrs = curNode.getAttrMapRef()

    # Dims 0 and 3 (batch / channel, assuming NHWC) must have stride and
    # window 1; dims 1 and 2 are H and W.
    stridesUsed = attrs["\"strides\""].getList().getILi()
    assert (stridesUsed[0] == 1) and (stridesUsed[3] == 1)
    strideH = stridesUsed[1]
    strideW = stridesUsed[2]

    kSizeUsed = attrs["\"ksize\""].getList().getILi()
    assert (kSizeUsed[0] == 1) and (kSizeUsed[3] == 1)
    kSizeH = kSizeUsed[1]
    kSizeW = kSizeUsed[2]

    paddingUsedStr = attrs["\"padding\""].getS()
    if paddingUsedStr == "\"SAME\"":
        # NOTE(review): one symmetric pad per dim only matches TF "SAME"
        # exactly for stride 1 and odd window sizes; even windows or larger
        # strides need asymmetric / stride-dependent padding. Confirm.
        zPadH = int((kSizeH - 1) / 2)
        zPadW = int((kSizeW - 1) / 2)
    elif paddingUsedStr == "\"VALID\"":
        zPadH = zPadW = 0
    else:
        # Unrecognised padding mode: -1 is passed through unchanged
        # (presumably a sentinel for the backend - verify).
        zPadH = zPadW = -1

    return (None, AST.UninterpFuncCall(
        extraNodeInfoDict[curNode.getName()][0],
        TFNodesAST.UninterpFuncCallNames.AvgPool.name,
        [
            AST.Int(kSizeH, 32),
            AST.Int(kSizeW, 32),
            AST.Int(zPadH, 32),
            AST.Int(zPadW, 32),
            AST.Int(strideH, 32),
            AST.Int(strideW, 32),
            AST.ID(dictNodeNameToOutVarStr[inputsRef[0]]),
        ]))
def Reshape(node, value_info, node_name_to_out_var_dict,
            innermost_let_ast_node, out_var_count, mtdAST):
    """Lower an ONNX ``Reshape`` node to a SeeDot ``AST.Reshape``.

    The target shape is the statically inferred shape of the node's output
    (``value_info[output][1]``). The second ONNX input (the shape tensor) is
    only counted by the arity check and never read. No explicit axis order
    is supplied (``None``).

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` after the new binding.
    """
    node = OnnxNode(node)
    if DEBUG:
        print(node)
    sources = node.inputs
    assert len(sources) == 2
    dst = node.outputs[0]

    source_var = AST.ID(node_name_to_out_var_dict[sources[0]])
    target_shape = list(value_info[dst][1])
    reshape_ast = AST.Reshape(source_var, target_shape, None)

    bound_name = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, reshape_ast, bound_name, mtdAST)
    node_name_to_out_var_dict[dst] = bound_name
    return (innermost_let_ast_node, out_var_count + 1)
def Split(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node,
          out_var_count, mtdAST):
    """Lower an ONNX ``Split`` node into one ``Split`` call per output.

    For output ``i`` of ``k`` outputs, emits ``Split(input, axis, i, k)`` as an
    ``AST.UninterpFuncCall``. Its shape is that output's inferred shape, and
    it is bound to a fresh variable recorded as that output's SeeDot name.

    ``axis`` defaults to 0 when the attribute is absent (ONNX default). A
    negative axis is normalised using the input's rank.

    NOTE(review): explicit ``split`` sizes (attribute, or second input from
    opset 13) are not forwarded. Only the output index and count are
    passed, so unequal splits depend on the backend - confirm.

    Returns:
        ``(innermost_let_ast_node, out_var_count)`` after all bindings.
    """
    node = OnnxNode(node)
    if DEBUG:
        print(node)
    inputsRef = node.inputs
    source_var = node_name_to_out_var_dict[inputsRef[0]]
    output_count = len(node.outputs)

    axis = node.attrs.get('axis', 0)
    if axis < 0:
        # ONNX allows negative axes counted from the last dimension.
        axis += len(value_info[inputsRef[0]][1])

    for cur_count, onnx_output in enumerate(node.outputs):
        seedot_output_ast = AST.UninterpFuncCall(
            list(value_info[onnx_output][1]), 'Split', [
                AST.ID(source_var),
                AST.Int(axis, 32, False),
                AST.Int(cur_count, 32, False),
                AST.Int(output_count, 32, False)
            ])
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1
        node_name_to_out_var_dict[onnx_output] = output_name
    return (innermost_let_ast_node, out_var_count)
def conv3d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node,
           out_var_count, mtdAST):
    """Lower a 3-D ONNX ``Conv`` node into SeeDot AST nodes.

    ``node`` must already expose ``inputs`` and ``attrs``; it is not wrapped
    in ``OnnxNode`` here. The following are emitted, each bound to a fresh
    variable:
      1. the input reshaped from [N, CI, D, H, W] to [N, D, H, W, CI];
      2. the filter reshaped via ``get_reshaped_filter_ast``;
      3. the ``'#'`` operation of the two;
      4. when a bias input is present, the reshaped bias and its ``'+'`` sum
         with the previous result.

    NOTE(review): ``strides`` and ``pads`` are validated (3 strides, 6 pads)
    but never forwarded to ``AST.Bop1``. Unless they are recovered elsewhere,
    strided or padded 3-D convolutions are lowered incorrectly. Compare
    ``conv3dtranspose``, which passes an options dict. Confirm intent.

    Returns:
        ``(innermost_let_ast_node, out_var_count, output_name)``, where
        ``output_name`` holds the final result.
    """
    onnx_inputs = node.inputs
    in_shape = value_info[onnx_inputs[0]][1]
    w_shape = value_info[onnx_inputs[1]][1]
    strides = node.attrs['strides']

    assert len(onnx_inputs) in (2, 3)
    assert len(strides) == 3
    assert value_info[node.inputs[1]][1][2:] == tuple(
        node.attrs['kernel_shape'])
    # Unpacked only to enforce exactly six pad values. ONNX orders them
    # [D_begin, H_begin, W_begin, D_end, H_end, W_end].
    (_d_begin, _h_begin, _w_begin,
     _d_end, _h_end, _w_end) = node.attrs['pads']
    assert in_shape[1] == w_shape[1]

    def emit(ast_node):
        """Bind ``ast_node`` to a fresh variable and return that name."""
        nonlocal innermost_let_ast_node, out_var_count
        var_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, ast_node, var_name, mtdAST)
        out_var_count += 1
        return var_name

    # Input: ONNX [N, CI, D, H, W] -> [N, D, H, W, CI].
    x_name = emit(get_reshaped_input_ast(onnx_inputs[0], value_info,
                                         node_name_to_out_var_dict))
    # Filter: ONNX [CO, CI1, FD, FH, FW] -> [FD, FH, FW, CI1, CO].
    w_name = emit(get_reshaped_filter_ast(onnx_inputs[1], value_info,
                                          node_name_to_out_var_dict))
    result_name = emit(AST.Bop1(AST.ID(x_name), getOperatorsIdx('#'),
                                AST.ID(w_name)))

    if len(onnx_inputs) == 3:
        b_name = emit(get_reshaped_bias_ast(onnx_inputs[2], value_info,
                                            node_name_to_out_var_dict, 3))
        result_name = emit(AST.Bop1(AST.ID(result_name), getOperatorsIdx('+'),
                                    AST.ID(b_name)))

    return (innermost_let_ast_node, out_var_count, result_name)
def conv2d(node, value_info, node_name_to_out_var_dict, innermost_let_ast_node,
           out_var_count, mtdAST):
    """Lower a 2-D ONNX ``Conv`` node into SeeDot AST nodes.

    ``node`` must already expose ``inputs`` and ``attrs``; it is not wrapped
    in ``OnnxNode`` here. The following are emitted, each bound to a fresh
    variable:
      1. the input reshaped from ONNX [N, CI, H, W] to [N, H, W, CI];
      2. the filter reshaped from ONNX [CO, CI1, FH, FW] to [FH, FW, CI1, CO];
      3. an ``AST.Convolution`` with strides, pads, dilations and group;
      4. when a bias input is present, the bias added to the result with
         SeeDot's ``ADDCIR`` operator (no explicit reshape of the bias).

    Missing attributes take the ONNX defaults: unit strides and dilations,
    zero pads, group 1. ``auto_pad`` is not consulted.

    Returns:
        ``(innermost_let_ast_node, out_var_count, output_name)``, where
        ``output_name`` holds the final result.
    """
    inputsRef = node.inputs
    inputShape = value_info[inputsRef[0]][1]
    filterShape = value_info[inputsRef[1]][1]
    attrs = node.attrs
    stridesUsed = attrs['strides'] if 'strides' in attrs else [1, 1]
    group = attrs['group'] if 'group' in attrs else 1
    padding = attrs['pads'] if 'pads' in attrs else [0, 0, 0, 0]
    # ONNX spells this attribute "dilations". The singular key is still
    # honoured for backward compatibility with any caller that set it.
    if 'dilations' in attrs:
        dilation = attrs['dilations']
    elif 'dilation' in attrs:
        dilation = attrs['dilation']
    else:
        dilation = [1, 1]

    assert len(inputsRef) == 2 or len(inputsRef) == 3
    assert len(stridesUsed) == 2
    # Grouped convolution: each filter sees CI / group input channels.
    assert inputShape[1] == filterShape[1] * group

    # Input: ONNX [N, CI, H, W] -> [N, H, W, CI].
    reshaped_input_name = get_new_var_name(out_var_count)
    reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                            node_name_to_out_var_dict)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST)
    out_var_count += 1

    # Filter: ONNX [CO, CI1, FH, FW] -> [FH, FW, CI1, CO].
    reshaped_filter_name = get_new_var_name(out_var_count)
    reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                              node_name_to_out_var_dict)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST)
    out_var_count += 1

    seedot_output_ast = AST.Convolution(AST.ID(reshaped_input_name),
                                        AST.ID(reshaped_filter_name),
                                        stridesUsed, padding, dilation, group)
    output_name = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
    out_var_count += 1

    if len(inputsRef) == 3:
        # Resolve the bias through the ONNX->SeeDot name map like every other
        # operand. Fall back to the raw ONNX name when it is not mapped.
        bias_var = node_name_to_out_var_dict.get(inputsRef[2], inputsRef[2])
        seedot_output_ast = AST.Bop1(AST.ID(output_name),
                                     SeeDotParser.ADDCIR, AST.ID(bias_var))
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

    return (innermost_let_ast_node, out_var_count, output_name)
def get_reshaped_output_ast(onnx_output_name, value_info, output_name):
    """Build an ``AST.Reshape`` taking SeeDot variable ``output_name`` to ONNX layout.

    Args:
        onnx_output_name: ONNX tensor whose inferred shape
            (``value_info[onnx_output_name][1]``) is the reshape target.
        value_info: mapping from ONNX tensor names to entries whose index 1
            is used as the shape.
        output_name: SeeDot variable holding the value to reshape.

    Returns:
        ``AST.Reshape(AST.ID(output_name), shape, get_onnx_order(shape))``.
    """
    target_shape = list(value_info[onnx_output_name][1])
    axis_order = get_onnx_order(target_shape)
    source = AST.ID(output_name)
    return AST.Reshape(source, target_shape, axis_order)
def get_reshaped_input_ast(input_name, value_info, node_name_to_out_var_dict):
    """Build an ``AST.Reshape`` taking an ONNX-layout tensor to SeeDot layout.

    The target shape and axis order come from ``get_seedot_shape_order``
    applied to the ONNX shape ``value_info[input_name][1]``. The source is
    the SeeDot variable currently bound to ``input_name``.

    Returns:
        ``AST.Reshape(source_id, seedot_shape, seedot_order)``.
    """
    shape_and_order = get_seedot_shape_order(list(value_info[input_name][1]))
    seedot_shape, seedot_order = shape_and_order
    source = AST.ID(node_name_to_out_var_dict[input_name])
    return AST.Reshape(source, seedot_shape, seedot_order)
def conv3dtranspose(node, value_info, node_name_to_out_var_dict,
                    innermost_let_ast_node, out_var_count, mtdAST):
    """Lower a 3-D ONNX ``ConvTranspose`` node into SeeDot AST nodes.

    ``node`` must already expose ``inputs``, ``outputs`` and ``attrs``; it is
    not wrapped in ``OnnxNode`` here. The following are emitted, each bound
    to a fresh variable:
      1. the input reshaped from ONNX [N, CI, D, H, W] to [N, D, H, W, CI];
      2. the filter reshaped via ``get_reshaped_filter_ast``;
      3. the ``'#T'`` operation with a padding/stride/output-size options dict;
      4. when a bias input is present, the reshaped bias and its ``'+'`` sum.

    Returns:
        ``(innermost_let_ast_node, out_var_count, output_name)``, where
        ``output_name`` holds the final result.
    """
    inputsRef = node.inputs
    inputShape = value_info[inputsRef[0]][1]
    filterShape = value_info[inputsRef[1]][1]
    stridesUsed = node.attrs['strides']
    outputShape = value_info[node.outputs[0]][1]

    # Inputs are (X, W) or (X, W, B).
    assert len(inputsRef) == 2 or len(inputsRef) == 3
    assert len(stridesUsed) == 3
    assert value_info[node.inputs[1]][1][2:] == tuple(
        node.attrs['kernel_shape'])

    # ONNX lists every "begin" pad first, then every "end" pad:
    # [D_begin, H_begin, W_begin, D_end, H_end, W_end].
    [zPadDLeft, zPadHLeft, zPadWLeft,
     zPadDRight, zPadHRight, zPadWRight] = node.attrs['pads']

    options = {
        AST.PaddingKeysDict.FD: filterShape[2],
        AST.PaddingKeysDict.FH: filterShape[3],
        AST.PaddingKeysDict.FW: filterShape[4],
        AST.PaddingKeysDict.zPadDLeft: zPadDLeft,
        AST.PaddingKeysDict.zPadDRight: zPadDRight,
        AST.PaddingKeysDict.zPadHLeft: zPadHLeft,
        AST.PaddingKeysDict.zPadHRight: zPadHRight,
        AST.PaddingKeysDict.zPadWLeft: zPadWLeft,
        AST.PaddingKeysDict.zPadWRight: zPadWRight,
        AST.PaddingKeysDict.strideD: stridesUsed[0],
        AST.PaddingKeysDict.strideH: stridesUsed[1],
        AST.PaddingKeysDict.strideW: stridesUsed[2],
        AST.PaddingKeysDict.ConvDim: 3,
        AST.PaddingKeysDict.outputImgD: outputShape[2],
        AST.PaddingKeysDict.outputImgH: outputShape[3],
        AST.PaddingKeysDict.outputImgW: outputShape[4],
    }

    # ConvTranspose filters are [CI, CO/group, ...], so dim 0 matches input CI.
    assert inputShape[1] == filterShape[0]

    # Input: ONNX [N, CI, D, H, W] -> [N, D, H, W, CI].
    reshaped_input_name = get_new_var_name(out_var_count)
    reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                            node_name_to_out_var_dict)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, reshaped_input, reshaped_input_name, mtdAST)
    out_var_count += 1

    # Filter: ONNX [CI, CO, FD, FH, FW] -> [FD, FH, FW, CI1, CO].
    reshaped_filter_name = get_new_var_name(out_var_count)
    reshaped_filter = get_reshaped_filter_ast(inputsRef[1], value_info,
                                              node_name_to_out_var_dict)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, reshaped_filter, reshaped_filter_name, mtdAST)
    out_var_count += 1

    seedot_output_ast = AST.BOp(AST.ID(reshaped_input_name),
                                getOperatorsIdx('#T'),
                                AST.ID(reshaped_filter_name), options)
    output_name = get_new_var_name(out_var_count)
    innermost_let_ast_node = update_program_with_new_node(
        innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
    out_var_count += 1

    if len(inputsRef) == 3:
        reshaped_bias_name = get_new_var_name(out_var_count)
        reshaped_bias = get_reshaped_bias_ast(inputsRef[2], value_info,
                                              node_name_to_out_var_dict, 3)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, reshaped_bias, reshaped_bias_name, mtdAST)
        out_var_count += 1

        seedot_output_ast = AST.BOp(AST.ID(output_name), getOperatorsIdx('+'),
                                    AST.ID(reshaped_bias_name), options)
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

    return (innermost_let_ast_node, out_var_count, output_name)
def visitId(self, ctx: SeeDotParser.IdContext):
    """Visit an identifier parse node and wrap its source text in ``AST.ID``."""
    identifier_text = ctx.Id().getText()
    return AST.ID(identifier_text)
def BroadcastGradientArgs(graph: Graph.Graph, curNode: Graph.Node,
                          dictNodeNameToOutVarStr: dict,
                          extraNodeInfoDict: dict):
    """Placeholder lowering for TensorFlow's ``BroadcastGradientArgs`` op.

    Not implemented. None of the arguments are read, and the result is
    always ``(None, AST.ID("temp"))``.

    NOTE(review): any consumer of this value receives a dummy identifier.
    Presumably acceptable only because gradient nodes are unused at
    inference time - confirm.
    """
    # TODO: implement a real lowering.
    placeholder = AST.ID("temp")
    return (None, placeholder)
def StopGradient(graph: Graph.Graph, curNode: Graph.Node,
                 dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
    """Lower TensorFlow ``StopGradient`` as an identity.

    The node must have exactly one input. The result is simply a reference
    to the variable already holding that input.

    Returns:
        ``(None, AST.ID(<input variable>))``.
    """
    sources = curNode.getInputsRef()
    assert len(sources) == 1
    forwarded_var = dictNodeNameToOutVarStr[sources[0]]
    return (None, AST.ID(forwarded_var))