def convert_linear_layer(node, params):
    """Convert linear layer from onnx node and params."""
    # Default Gemm attributes
    dc = dict(
        transpose_weight=True,
        transpose_activation=False,
        weight_multiplier=1,
        bias_multiplier=1,
    )
    dc.update(extract_attributes(node))
    for attr in node.attribute:
        if attr.name in ["transA"] and extract_attr_values(attr) != 0:
            raise NotImplementedError(
                "Not implemented for attr.name={} and value!=0.".format(attr.name)
            )

    kwargs = {}
    weight, bias = extract_params(params)
    kwargs["bias"] = bias is not None
    kwargs["in_features"] = weight.dims[1]
    kwargs["out_features"] = weight.dims[0]

    # initialize layer and load weights
    layer = nn.Linear(**kwargs)
    load_params(layer, weight, bias)

    # apply onnx gemm attributes
    if dc.get("transpose_weight"):
        layer.weight.data = layer.weight.data.t()

    layer.weight.data *= dc.get("weight_multiplier")
    if layer.bias is not None:
        layer.bias.data *= dc.get("bias_multiplier")

    return layer
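# Usage sketch (hypothetical; the file name and node lookup are illustrative,
# not part of this module):
#
#   import onnx
#   model = onnx.load("model.onnx")
#   gemm_node = next(n for n in model.graph.node if n.op_type == "Gemm")
#   params = [t for t in model.graph.initializer if t.name in gemm_node.input]
#   linear = convert_linear_layer(gemm_node, params)
#   # linear is an initialized nn.Linear; the Gemm attributes are folded in
#   # via the transpose_weight / weight_multiplier handling above.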
def convert_instance_norm_layer(node, params):
    kwargs = extract_attributes(node)
    # Skip input dimension check, not possible before forward pass
    layer = InstanceNormWrapper
    torch_params = [
        torch.from_numpy(numpy_helper.to_array(param)) for param in params
    ]

    # Initialize layer and load weights
    layer = layer(torch_params, **kwargs)

    return layer
def convert_batch_norm_layer(node, params):
    kwargs = extract_attributes(node)
    # Skip input dimension check, not possible before forward pass
    layer = BatchNormWrapper
    torch_params = [_deserialize_to_torch(param) for param in params]

    # Initialize layer and load weights
    layer = layer(torch_params, **kwargs)

    return layer
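# Note: like InstanceNormWrapper above, BatchNormWrapper defers picking the
# concrete 1d/2d/3d torch layer until the first forward pass, when the input
# dimensionality is known. The ONNX BatchNormalization parameters arrive in
# (scale/weight, bias, running_mean, running_var) order, which torch_params
# preserves.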
def convert_layer(node, layer_type, params=None):
    """Use to convert Conv, MaxPool, AvgPool layers."""
    assert layer_type in [
        "Conv",
        "ConvTranspose",
        "MaxPool",
        "AvgPool",
    ], "Incorrect layer type: {}".format(layer_type)
    kwargs = extract_attributes(node)
    kernel_size_length = len(kwargs["kernel_size"])
    try:
        layer = getattr(nn, "{}{}d".format(layer_type, kernel_size_length))
    except AttributeError:
        raise ValueError(
            "Unexpected length of kernel_size dimension: {}".format(
                kernel_size_length
            )
        )

    pad_layer = None
    if params:
        weight, bias = extract_params(params)
        kwargs["bias"] = bias is not None
        kwargs["in_channels"] = weight.dims[1] * kwargs.get("groups", 1)
        kwargs["out_channels"] = weight.dims[0]

        if layer_type == "ConvTranspose":
            kwargs["in_channels"], kwargs["out_channels"] = (
                kwargs["out_channels"],
                kwargs["in_channels"],
            )

        # if padding is a layer, remove from kwargs and prepend later
        if "padding" in kwargs and isinstance(kwargs["padding"], nn.Module):
            pad_layer = kwargs.pop("padding")

        # initialize layer and load weights
        layer = layer(**kwargs)
        load_params(layer, weight, bias)
    else:
        # initialize operations without parameters (MaxPool, AvgPool, etc.)
        if layer_type == "MaxPool":
            kwargs["return_indices"] = True

        # if padding is a layer, remove from kwargs and prepend later
        if "padding" in kwargs and isinstance(kwargs["padding"], nn.Module):
            pad_layer = kwargs.pop("padding")

        layer = layer(**kwargs)

    if pad_layer is not None:
        layer = nn.Sequential(pad_layer, layer)

    return layer
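# Example (illustrative): a Conv node whose kernel_size attribute has length 2
# resolves to nn.Conv2d via getattr(nn, "Conv2d"); lengths 1 and 3 resolve to
# nn.Conv1d and nn.Conv3d. Asymmetric ONNX pads cannot be expressed through the
# layer's own padding kwarg, so extract_attributes may hand back a padding
# nn.Module instead (e.g. a ConstantPad layer), which gets prepended here in an
# nn.Sequential.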
def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=True):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_graph: onnx.GraphProto
        Loaded onnx model's GraphProto.
    opset_version: int
        ONNX model's opset version.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.
    enable_pruning: bool
        Track kept/pruned indices between different calls to forward pass.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    weights = {tensor.name: tensor for tensor in onnx_graph.initializer}

    for i, node in enumerate(onnx_graph.node):
        # extract only useful inputs
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Add":
            op = Add(feature_dim=batch_dim + 1)  # 0 for CV models and 1 for NLP
        elif node.op_type == "And":
            op = OperatorWrapper(torch.logical_and)
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "Cast":
            op = Cast(**extract_attributes(node))
        elif node.op_type == "Ceil":
            op = OperatorWrapper(torch.ceil)
        elif node.op_type == "Clip":
            op = Clip(**extract_attributes(node))
        elif node.op_type == "Concat":
            op = partial(torch.cat, **extract_attributes(node))
        elif node.op_type == "Constant":
            op = Constant(**extract_attributes(node))
        elif node.op_type == "ConstantOfShape":
            op = ConstantOfShape(**extract_attributes(node))
        elif node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "ConvTranspose":
            op = convert_layer(node, "ConvTranspose", params)
        elif node.op_type == "Div":
            op = Div()
        elif node.op_type == "Elu":
            op = nn.ELU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Equal":
            op = OperatorWrapper(torch.eq)
        elif node.op_type == "Erf":
            op = OperatorWrapper(torch.erf)
        elif node.op_type == "Exp":
            op = OperatorWrapper(torch.exp)
        elif node.op_type == "Expand":
            op = Expand()
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "Floor":
            op = OperatorWrapper(torch.floor)
        elif node.op_type == "Gather":
            op = Gather(**extract_attributes(node))
        elif node.op_type == "GatherND":
            op = GatherND(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
        elif node.op_type == "GlobalAveragePool":
            op = GlobalAveragePool()
        elif node.op_type == "Greater":
            op = OperatorWrapper(torch.greater)
        elif node.op_type == "Identity":
            op = nn.Identity()
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Less":
            op = OperatorWrapper(torch.less)
        elif node.op_type == "Log":
            op = OperatorWrapper(torch.log)
        elif node.op_type == "Loop":
            op = Loop(
                opset_version=opset_version,
                batch_dim=batch_dim,
                **extract_attributes(node),
            )
        elif node.op_type == "LSTM":
            op = convert_lstm_layer(node, weights)
        elif node.op_type == "MatMul":
            if params:
                weight = torch.from_numpy(numpy_helper.to_array(params[0]))
                op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
                op.weight.data = weight.t()

                # check if next node Add to add bias
                next_node = onnx_graph.node[i + 1]
                next_params = [
                    weights[par_name]
                    for par_name in next_node.input
                    if par_name in weights
                ]
                if next_params and next_node.op_type == "Add":
                    bias = torch.from_numpy(numpy_helper.to_array(next_params[0]))
                    op.bias = nn.Parameter(bias)
                    node.output.pop()
                    node.output.extend(next_node.output)
                    onnx_graph.node.pop(i + 1)  # remove next node
            else:
                op = MatMul()
        elif node.op_type == "Max":
            op = OperatorWrapper(torch.max)
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "Min":
            op = OperatorWrapper(torch.min)
        elif node.op_type == "Mul":
            op = OperatorWrapper(torch.mul)
        elif node.op_type == "NonMaxSuppression":
            op = NonMaxSuppression(**extract_attributes(node))
        elif node.op_type == "Not":
            op = OperatorWrapper(torch.logical_not)
        elif node.op_type == "OneHot":
            op = OneHot(**extract_attributes(node))
        elif node.op_type == "Or":
            op = OperatorWrapper(torch.logical_or)
        elif node.op_type == "Pad":
            op = Pad(**extract_attributes(node))
        elif node.op_type == "Pow":
            op = OperatorWrapper(torch.pow)
        elif node.op_type == "PRelu":
            op = PRelu()
        elif node.op_type == "Range":
            op = Range()
        elif node.op_type == "Reciprocal":
            op = OperatorWrapper(torch.reciprocal)
        elif node.op_type == "ReduceMax":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.max, **kwargs)
        elif node.op_type == "ReduceMean":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.mean, **kwargs)
        elif node.op_type == "ReduceMin":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.min, **kwargs)
        elif node.op_type == "ReduceProd":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.prod, **kwargs)
        elif node.op_type == "ReduceSum":
            op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "Reshape":
            shape = list(
                filter(lambda x: x.name == node.input[1], onnx_graph.initializer)
            )
            shape = np.copy(numpy_helper.to_array(shape[0])) if shape else None
            op = Reshape(enable_pruning, shape)
        elif node.op_type == "Resize":
            op = Resize(**extract_attributes(node))
        elif node.op_type == "Scatter":
            op = Scatter(**extract_attributes(node))
        elif node.op_type == "ScatterElements":
            op = ScatterElements(**extract_attributes(node))
        elif node.op_type == "ScatterND":
            op = ScatterND()
        elif node.op_type == "Shape":
            op = Shape()
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "Slice":
            op = Slice(**extract_attributes(node))
        elif node.op_type == "Softmax":
            kwargs = dict(dim=-1)
            kwargs.update(extract_attributes(node))
            op = nn.Softmax(**kwargs)
        elif node.op_type == "Softplus":
            op = nn.Softplus(beta=1)
        elif node.op_type == "Softsign":
            op = nn.Softsign()
        elif node.op_type == "Split":
            kwargs = extract_attributes(node)
            # if the split_size_or_sections is not in node attributes,
            # the number_of_splits becomes the number of node outputs
            if "split_size_or_sections" not in kwargs:
                kwargs["number_of_splits"] = len(node.output)
            op = Split(enable_pruning, **kwargs)
        elif node.op_type == "Sqrt":
            op = OperatorWrapper(torch.sqrt)
        elif node.op_type == "Squeeze":
            op = Squeeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Sub":
            op = OperatorWrapper(torch.sub)
        elif node.op_type == "Tanh":
            op = OperatorWrapper(torch.tanh)
        elif node.op_type == "ThresholdedRelu":
            op = ThresholdedRelu(**extract_attributes(node))
        elif node.op_type == "Tile":
            op = Tile()
        elif node.op_type == "TopK":
            op = TopK()
        elif node.op_type == "Transpose":
            op = Transpose(**extract_attributes(node))
        elif node.op_type == "Unsqueeze":
            op = Unsqueeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Upsample":
            op = Upsample(**extract_attributes(node))
        elif node.op_type == "Where":
            op = Where()
        else:
            op = getattr(torch, node.op_type.lower(), None)
            if op is None:
                raise NotImplementedError(
                    "Conversion not implemented for op_type={}.".format(node.op_type)
                )
            else:
                print(
                    "Automatic inference of operator: {}".format(node.op_type.lower())
                )

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
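# Usage sketch (hypothetical; the file name is illustrative):
#
#   import onnx
#   model = onnx.load("model.onnx")
#   opset_version = model.opset_import[0].version
#   ops = dict()
#   for op_id, op_name, op in convert_operations(model.graph, opset_version):
#       ops[op_id] = op  # e.g. op_name "Conv_123" -> an initialized nn.Conv2d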
def convert_lstm_layer(node, weights):
    """Convert LSTM layer from onnx node and params."""
    params_tuple = extract_and_load_params_lstm(node, weights)
    (X, W, R, B, sequence_lens, initial_h, initial_c, P) = params_tuple
    if initial_h is not None:
        raise NotImplementedError("LSTM initial_h not yet implemented.")
    if initial_c is not None:
        raise NotImplementedError("LSTM initial_c not yet implemented.")
    if P is not None:
        raise NotImplementedError("LSTM P not yet implemented.")

    dc = dict(
        activation_alpha=None,
        activation_beta=None,
        activations=None,
        clip=None,
        direction="forward",
        hidden_size=None,
        input_forget=0,
        layout=0,
    )
    dc.update(extract_attributes(node))
    if dc["activation_alpha"] is not None:
        raise NotImplementedError(
            "LSTM activation_alpha {}.".format(dc["activation_alpha"])
        )
    if dc["activation_beta"] is not None:
        raise NotImplementedError(
            "LSTM activation_beta {}.".format(dc["activation_beta"])
        )
    if dc["activations"] is not None:
        # TODO allow if torch-compatible activations are set explicitly
        raise NotImplementedError("LSTM activations {}.".format(dc["activations"]))
    if dc["clip"] is not None:
        raise NotImplementedError("LSTM clip {}.".format(dc["clip"]))
    if dc["direction"] not in ("forward", "bidirectional"):
        raise ValueError("LSTM direction {}.".format(dc["direction"]))
    if dc["hidden_size"] is None:
        raise ValueError("LSTM hidden_size is None.")
    if dc["input_forget"] != 0:
        raise NotImplementedError("LSTM input_forget {}.".format(dc["input_forget"]))
    if dc["layout"] != 0:
        raise NotImplementedError(
            "LSTM not implemented for layout={}.".format(dc["layout"])
        )

    kwargs = {
        "input_size": W.shape[2],
        "hidden_size": dc["hidden_size"],
        "num_layers": 1,
        "bias": True,
        "batch_first": False,
        "dropout": 0,
        "bidirectional": dc["direction"] == "bidirectional",
    }
    lstm_layer = nn.LSTM(**kwargs)

    input_size = kwargs["input_size"]
    hidden_size = kwargs["hidden_size"]
    num_directions = kwargs["bidirectional"] + 1
    num_layers = 1

    if kwargs["bidirectional"]:
        # Set input-hidden weights
        W_iofc = W.transpose(0, 1).view(4 * hidden_size, num_directions, input_size)
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            W_ifco = torch.cat(
                tensors=(
                    W_iofc[0:hidden_size, dir_dim, :],
                    W_iofc[2 * hidden_size : 4 * hidden_size, dir_dim, :],
                    W_iofc[hidden_size : 2 * hidden_size, dir_dim, :],
                ),
                dim=0,
            )
            getattr(lstm_layer, "weight_ih_l0{}".format(dir_str)).data = W_ifco

        # Set hidden-hidden weights
        R_iofc = R.transpose(0, 1).view(4 * hidden_size, num_directions, hidden_size)
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            R_ifco = torch.cat(
                tensors=(
                    R_iofc[0:hidden_size, dir_dim, :],
                    R_iofc[2 * hidden_size : 4 * hidden_size, dir_dim, :],
                    R_iofc[hidden_size : 2 * hidden_size, dir_dim, :],
                ),
                dim=0,
            )
            getattr(lstm_layer, "weight_hh_l0{}".format(dir_str)).data = R_ifco

        # Set input-hidden biases
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            Wb_iofc = B[dir_dim, 0 : 4 * hidden_size]
            Wb_ifco = torch.cat(
                tensors=(
                    Wb_iofc[0:hidden_size],
                    Wb_iofc[2 * hidden_size : 4 * hidden_size],
                    Wb_iofc[hidden_size : 2 * hidden_size],
                ),
                dim=0,
            )
            getattr(lstm_layer, "bias_ih_l0{}".format(dir_str)).data = Wb_ifco

        # Set hidden-hidden biases
        for dir_dim, dir_str in [(0, ""), (1, "_reverse")]:
            Rb_iofc = B[dir_dim, 4 * hidden_size :]
            Rb_ifco = torch.cat(
                tensors=(
                    Rb_iofc[0:hidden_size],
                    Rb_iofc[2 * hidden_size : 4 * hidden_size],
                    Rb_iofc[hidden_size : 2 * hidden_size],
                ),
                dim=0,
            )
            getattr(lstm_layer, "bias_hh_l0{}".format(dir_str)).data = Rb_ifco
    else:
        # Set input-hidden weights
        W_iofc = W.transpose(0, 1).view(4 * hidden_size, input_size)
        W_ifco = torch.cat(
            tensors=(
                W_iofc[0:hidden_size, :],
                W_iofc[2 * hidden_size : 4 * hidden_size, :],
                W_iofc[hidden_size : 2 * hidden_size, :],
            ),
            dim=0,
        )
        getattr(lstm_layer, "weight_ih_l0").data = W_ifco

        # Set hidden-hidden weights
        R_iofc = R.transpose(0, 1).view(4 * hidden_size, hidden_size)
        R_ifco = torch.cat(
            tensors=(
                R_iofc[0:hidden_size, :],
                R_iofc[2 * hidden_size : 4 * hidden_size, :],
                R_iofc[hidden_size : 2 * hidden_size, :],
            ),
            dim=0,
        )
        getattr(lstm_layer, "weight_hh_l0").data = R_ifco

        # Set input-hidden biases
        Wb_iofc = B[0, 0 : 4 * hidden_size]
        Wb_ifco = torch.cat(
            tensors=(
                Wb_iofc[0:hidden_size],
                Wb_iofc[2 * hidden_size : 4 * hidden_size],
                Wb_iofc[hidden_size : 2 * hidden_size],
            ),
            dim=0,
        )
        getattr(lstm_layer, "bias_ih_l0").data = Wb_ifco

        # Set hidden-hidden biases
        Rb_iofc = B[0, 4 * hidden_size :]
        Rb_ifco = torch.cat(
            tensors=(
                Rb_iofc[0:hidden_size],
                Rb_iofc[2 * hidden_size : 4 * hidden_size],
                Rb_iofc[hidden_size : 2 * hidden_size],
            ),
            dim=0,
        )
        getattr(lstm_layer, "bias_hh_l0").data = Rb_ifco

    layer = LSTMWrapper(lstm_layer)

    return layer
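# Note on the gate reshuffling above: ONNX packs the LSTM gates in
# (input, output, forget, cell) order, while torch.nn.LSTM expects
# (input, forget, cell, output), so each weight/bias block is re-concatenated
# as [0:h], [2h:4h], [h:2h] with h = hidden_size.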