from functools import partial

import numpy as np
import torch
from onnx import numpy_helper
from torch import nn

# NOTE: the operator wrappers (Add, Cast, Clip, OperatorWrapper, ...) and the
# convert_* / extract_attributes helpers used below are defined elsewhere in
# this package.


def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=True):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_graph: onnx.GraphProto
        Loaded onnx model's GraphProto.
    opset_version: int
        ONNX model's opset version.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.
    enable_pruning: bool
        Track kept/pruned indices between different calls to forward pass.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    weights = {tensor.name: tensor for tensor in onnx_graph.initializer}

    for i, node in enumerate(onnx_graph.node):
        # extract only useful inputs
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Add":
            op = Add(feature_dim=batch_dim + 1)  # 0 for CV models and 1 for NLP
        elif node.op_type == "And":
            op = OperatorWrapper(torch.logical_and)
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "Cast":
            op = Cast(**extract_attributes(node))
        elif node.op_type == "Ceil":
            op = OperatorWrapper(torch.ceil)
        elif node.op_type == "Clip":
            op = Clip(**extract_attributes(node))
        elif node.op_type == "Concat":
            op = partial(torch.cat, **extract_attributes(node))
        elif node.op_type == "Constant":
            op = Constant(**extract_attributes(node))
        elif node.op_type == "ConstantOfShape":
            op = ConstantOfShape(**extract_attributes(node))
        elif node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "ConvTranspose":
            op = convert_layer(node, "ConvTranspose", params)
        elif node.op_type == "Div":
            op = Div()
        elif node.op_type == "Elu":
            op = nn.ELU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Equal":
            op = OperatorWrapper(torch.eq)
        elif node.op_type == "Erf":
            op = OperatorWrapper(torch.erf)
        elif node.op_type == "Exp":
            op = OperatorWrapper(torch.exp)
        elif node.op_type == "Expand":
            op = Expand()
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "Floor":
            op = OperatorWrapper(torch.floor)
        elif node.op_type == "Gather":
            op = Gather(**extract_attributes(node))
        elif node.op_type == "GatherND":
            op = GatherND(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
        elif node.op_type == "GlobalAveragePool":
            op = GlobalAveragePool()
        elif node.op_type == "Greater":
            op = OperatorWrapper(torch.greater)
        elif node.op_type == "Identity":
            op = nn.Identity()
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Less":
            op = OperatorWrapper(torch.less)
        elif node.op_type == "Log":
            op = OperatorWrapper(torch.log)
        elif node.op_type == "Loop":
            op = Loop(
                opset_version=opset_version,
                batch_dim=batch_dim,
                **extract_attributes(node),
            )
        elif node.op_type == "LSTM":
            op = convert_lstm_layer(node, weights)
        elif node.op_type == "MatMul":
            if params:
                weight = torch.from_numpy(numpy_helper.to_array(params[0]))
                op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
                op.weight.data = weight.t()

                # check if next node Add to add bias
                next_node = onnx_graph.node[i + 1]
                next_params = [
                    weights[par_name]
                    for par_name in next_node.input
                    if par_name in weights
                ]
                if next_params and next_node.op_type == "Add":
                    bias = torch.from_numpy(numpy_helper.to_array(next_params[0]))
                    op.bias = nn.Parameter(bias)
                    node.output.pop()
                    node.output.extend(next_node.output)
                    onnx_graph.node.pop(i + 1)  # remove next node
            else:
                op = MatMul()
        elif node.op_type == "Max":
            op = OperatorWrapper(torch.max)
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "Min":
            op = OperatorWrapper(torch.min)
        elif node.op_type == "Mul":
            op = OperatorWrapper(torch.mul)
        elif node.op_type == "NonMaxSuppression":
            op = NonMaxSuppression(**extract_attributes(node))
        elif node.op_type == "Not":
            op = OperatorWrapper(torch.logical_not)
        elif node.op_type == "OneHot":
            op = OneHot(**extract_attributes(node))
        elif node.op_type == "Or":
            op = OperatorWrapper(torch.logical_or)
        elif node.op_type == "Pad":
            op = Pad(**extract_attributes(node))
        elif node.op_type == "Pow":
            op = OperatorWrapper(torch.pow)
        elif node.op_type == "PRelu":
            op = PRelu()
        elif node.op_type == "Range":
            op = Range()
        elif node.op_type == "Reciprocal":
            op = OperatorWrapper(torch.reciprocal)
        elif node.op_type == "ReduceMax":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.max, **kwargs)
        elif node.op_type == "ReduceMean":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.mean, **kwargs)
        elif node.op_type == "ReduceMin":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.min, **kwargs)
        elif node.op_type == "ReduceProd":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.prod, **kwargs)
        elif node.op_type == "ReduceSum":
            op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "Reshape":
            shape = list(
                filter(lambda x: x.name == node.input[1], onnx_graph.initializer)
            )
            shape = np.copy(numpy_helper.to_array(shape[0])) if shape else None
            op = Reshape(enable_pruning, shape)
        elif node.op_type == "Resize":
            op = Resize(**extract_attributes(node))
        elif node.op_type == "Scatter":
            op = Scatter(**extract_attributes(node))
        elif node.op_type == "ScatterElements":
            op = ScatterElements(**extract_attributes(node))
        elif node.op_type == "ScatterND":
            op = ScatterND()
        elif node.op_type == "Shape":
            op = Shape()
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "Slice":
            op = Slice(**extract_attributes(node))
        elif node.op_type == "Softmax":
            kwargs = dict(dim=-1)
            kwargs.update(extract_attributes(node))
            op = nn.Softmax(**kwargs)
        elif node.op_type == "Softplus":
            op = nn.Softplus(beta=1)
        elif node.op_type == "Softsign":
            op = nn.Softsign()
        elif node.op_type == "Split":
            kwargs = extract_attributes(node)
            # if the split_size_or_sections is not in node attributes,
            # the number_of_splits becomes the number of node outputs
            if "split_size_or_sections" not in kwargs:
                kwargs["number_of_splits"] = len(node.output)
            op = Split(enable_pruning, **kwargs)
        elif node.op_type == "Sqrt":
            op = OperatorWrapper(torch.sqrt)
        elif node.op_type == "Squeeze":
            op = Squeeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Sub":
            op = OperatorWrapper(torch.sub)
        elif node.op_type == "Tanh":
            op = OperatorWrapper(torch.tanh)
        elif node.op_type == "ThresholdedRelu":
            op = ThresholdedRelu(**extract_attributes(node))
        elif node.op_type == "Tile":
            op = Tile()
        elif node.op_type == "TopK":
            op = TopK()
        elif node.op_type == "Transpose":
            op = Transpose(**extract_attributes(node))
        elif node.op_type == "Unsqueeze":
            op = Unsqueeze(opset_version=opset_version, **extract_attributes(node))
        elif node.op_type == "Upsample":
            op = Upsample(**extract_attributes(node))
        elif node.op_type == "Where":
            op = Where()
        else:
            op = getattr(torch, node.op_type.lower(), None)
            if op is None:
                raise NotImplementedError(
                    "Conversion not implemented for op_type={}.".format(node.op_type)
                )
            else:
                print(
                    "Automatic inference of operator: {}".format(node.op_type.lower())
                )

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
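

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original source): load an
    # ONNX model, read its opset version, and materialize the converted
    # PyTorch operators. The file name "model.onnx" is a placeholder.
    import onnx

    model = onnx.load("model.onnx")
    opset_version = model.opset_import[0].version
    converted = {
        op_id: (op_name, op)
        for op_id, op_name, op in convert_operations(model.graph, opset_version)
    }
    print("converted {} operators".format(len(converted)))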
# Earlier variant of convert_operations: it takes the full ModelProto rather
# than a GraphProto and supports a smaller operator set.
def convert_operations(onnx_model, batch_dim=0):
    """
    Convert onnx model operations. Yields onnx's operator_id, operator_name and
    converted pytorch operator.

    Parameters
    ----------
    onnx_model: onnx.ModelProto
        Loaded onnx model.
    batch_dim: int
        Usually 0 for computer vision models and 1 for NLP models.

    Returns
    -------
    iterator: (op_id, op_name, op)
    """
    weights = {tensor.name: tensor for tensor in onnx_model.graph.initializer}

    for i, node in enumerate(onnx_model.graph.node):
        # extract only useful inputs
        params = [weights[par_name] for par_name in node.input if par_name in weights]

        if node.op_type == "Conv":
            op = convert_layer(node, "Conv", params)
        elif node.op_type == "Relu":
            op = nn.ReLU(inplace=True)
        elif node.op_type == "LeakyRelu":
            op = nn.LeakyReLU(**extract_attributes(node), inplace=True)
        elif node.op_type == "Sigmoid":
            op = nn.Sigmoid()
        elif node.op_type == "MaxPool":
            op = convert_layer(node, "MaxPool")
        elif node.op_type == "AveragePool":
            op = convert_layer(node, "AvgPool")
        elif node.op_type == "Flatten":
            op = Flatten(**extract_attributes(node))
        elif node.op_type == "Gemm":
            op = convert_linear_layer(node, params)
            op.feature_dim = batch_dim + 1  # Necessary for transformers
        elif node.op_type == "BatchNormalization":
            op = convert_batch_norm_layer(node, params=params)
        elif node.op_type == "InstanceNormalization":
            op = convert_instance_norm_layer(node, params=params)
        elif node.op_type == "Concat":
            op = Concat(**extract_attributes(node))
        elif node.op_type == "Constant":
            # how Constant ops are handled: wrap the value so it is
            # returned whenever the op is called
            op = value_wrapper(torch.from_numpy(extract_attributes(node)["constant"]))
        elif node.op_type == "Reshape":
            shape = list(
                filter(lambda x: x.name == node.input[1], onnx_model.graph.initializer)
            )
            shape = numpy_helper.to_array(shape[0]) if shape else None
            # guard against a missing static shape (shape passed as runtime input)
            op = Reshape(tuple(shape) if shape is not None else None)
        elif node.op_type == "Shape":
            op = Shape()
        elif node.op_type == "Gather":
            op = Gather(**extract_attributes(node))
        elif node.op_type == "Squeeze":
            op = Squeeze(**extract_attributes(node))
        elif node.op_type == "Unsqueeze":
            op = partial(torch.unsqueeze, **extract_attributes(node))
        elif node.op_type == "ConstantOfShape":
            op = ConstantOfShape(**extract_attributes(node))
        elif node.op_type == "Slice":
            op = Slice(**extract_attributes(node))
        elif node.op_type == "Cast":
            op = Cast(**extract_attributes(node))
        elif node.op_type == "Where":
            op = Where()
        elif node.op_type == "Equal":
            op = torch.eq
        elif node.op_type == "Mul":
            op = Mul(**extract_attributes(node))
        elif node.op_type == "Div":
            op = torch.true_divide
        elif node.op_type == "MatMul":
            if params:
                weight = torch.from_numpy(numpy_helper.to_array(params[0]))
                op = nn.Linear(weight.shape[0], weight.shape[1], bias=False)
                op.weight.data = weight.t()

                # check if next node Add to add bias
                next_node = onnx_model.graph.node[i + 1]
                next_params = [
                    weights[par_name]
                    for par_name in next_node.input
                    if par_name in weights
                ]
                if next_params and next_node.op_type == "Add":
                    bias = torch.from_numpy(numpy_helper.to_array(next_params[0]))
                    op.bias = nn.Parameter(bias)
                    node.output.pop()
                    node.output.extend(next_node.output)
                    onnx_model.graph.node.pop(i + 1)  # remove next node
            else:
                op = MatMul()
        elif node.op_type == "Sub":
            op = torch.sub
        elif node.op_type == "Pow":
            op = torch.pow
        elif node.op_type == "Sqrt":
            op = torch.sqrt
        elif node.op_type == "Softmax":
            op = nn.Softmax(dim=1)
        elif node.op_type == "Transpose":
            op = partial(torch.Tensor.permute, **extract_attributes(node))
        elif node.op_type == "Split":
            kwargs = extract_attributes(node)
            # if the split_size_or_sections is not in node attributes,
            # the number_of_splits becomes the number of node outputs
            if "split_size_or_sections" not in kwargs:
                kwargs["number_of_splits"] = len(node.output)
            op = Split(**kwargs)
        elif node.op_type == "ReduceMean":
            kwargs = dict(keepdim=True)
            kwargs.update(extract_attributes(node))
            op = partial(torch.mean, **kwargs)
        elif node.op_type == "Add":
            op = Add()
        elif node.op_type == "GlobalAveragePool":
            op = GlobalAveragePool()
        elif node.op_type == "ConvTranspose":
            op = convert_layer(node, "ConvTranspose", params)
        elif node.op_type == "Identity":
            op = nn.Identity()
        elif node.op_type == "Resize":
            op = Resize(**extract_attributes(node))
        elif node.op_type == "Upsample":
            op = Upsample(**extract_attributes(node))
        elif node.op_type == "OneHot":
            op = OneHot(**extract_attributes(node))
        elif node.op_type == "Pad":
            op = Pad(**extract_attributes(node))
        elif node.op_type == "Clip":
            op = Clamp(**extract_attributes(node))
        elif node.op_type == "Tanh":
            op = torch.tanh
        elif node.op_type == "Erf":
            op = torch.erf
        elif node.op_type == "Log":
            op = torch.log
        elif node.op_type == "Exp":
            op = torch.exp
        elif node.op_type == "LRN":
            op = nn.LocalResponseNorm(**extract_attributes(node))
        elif node.op_type == "Dropout":
            # Dropout is a no-op at inference time (module.eval())
            op = nn.Dropout(p=1.0)
        else:
            op = getattr(torch, node.op_type.lower(), None)
            if op is None:
                raise NotImplementedError(
                    "Conversion not implemented for op_type={}.".format(node.op_type)
                )
            else:
                print(
                    "Automatic inference of operator: {}".format(node.op_type.lower())
                )

        op_name = "{}_{}".format(node.op_type, node.output[0])
        op_id = node.output[0]
        yield op_id, op_name, op
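

# value_wrapper is used by the legacy variant above but is not defined in this
# section. A minimal sketch of such a helper, assuming it only needs to close
# over the constant and return it whenever the op is invoked:
def value_wrapper(value):
    def wrapper(*args, **kwargs):
        # ignore runtime inputs; Constant nodes always yield the stored value
        return value

    return wrapper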
import pytest


# Illustrative parametrization (assumed, not from the original source): a 4D
# NCHW input upsampled 2x along both spatial dimensions by the project-local
# Upsample operator converted above.
@pytest.mark.parametrize(
    "inp, scales, new_shape",
    [(torch.rand(1, 3, 4, 4), torch.tensor([1.0, 1.0, 2.0, 2.0]), [1, 3, 8, 8])],
)
def test_upsample(inp, scales, new_shape):
    op = Upsample()
    out = op(inp, scales)
    assert list(out.shape) == new_shape