def test_get_attributes():
    attributes = {
        "kernel": [3, 3],
        "padding": [1, 1, 1, 1],
    }
    node = make_node("Conv", ["X"], ["Y"], **attributes)
    assert get_node_attributes(node) == attributes

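# For context, a minimal sketch of what get_node_attributes could look like,
# assuming it simply maps each AttributeProto on the node to its Python value
# (the real helper lives elsewhere in the repo and may normalize values further;
# the name _get_node_attributes_sketch is hypothetical):
def _get_node_attributes_sketch(node):
    from onnx.helper import get_attribute_value

    return {
        attribute.name: get_attribute_value(attribute)
        for attribute in node.attribute
    }
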
def _make_conv_layer_info(
    node: NodeProto,
    graph: ONNXGraph,
    execution_order: int,
) -> Optional[LayerInfo]:
    param = ModelInfo._get_node_param_array(node, graph)
    if param is None:
        return None

    attributes = get_node_attributes(node)
    kernel_shape = attributes.get("kernel_shape", list(param.shape[2:]))
    groups = attributes.get("group", 1)
    stride = attributes.get("strides", [1] * (len(param.shape) - 2))
    padding = attributes.get("pads", [0, 0] * (len(param.shape) - 2))

    return LayerInfo.conv_layer(
        name=node.input[1],
        in_channels=param.shape[1] * groups,
        out_channels=param.shape[0],
        kernel_shape=kernel_shape,
        bias=len(node.input) > 2,
        groups=groups,
        stride=stride,
        padding=padding,
        execution_order=execution_order,
        attributes=dict(sparsity=_param_sparsity(param)),
    )

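# Shape note (illustrative): ONNX stores a Conv weight as
# (out_channels, in_channels / groups, *kernel_shape), so for a hypothetical
# grouped conv with weight shape (64, 8, 3, 3) and group=8,
# param.shape[1] * groups recovers the layer's true in_channels of 8 * 8 = 64.
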
def _is_group_conv(node: onnx.NodeProto) -> bool:
    if node.op_type != "Conv":
        return False
    attrs = get_node_attributes(node)
    groups = attrs.get("group", 1)
    try:
        return int(groups) != 1
    except Exception:
        return False

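# Illustrative usage (the helper below is hypothetical, not part of the original
# module): a Conv node created with group=2 is treated as a group convolution,
# while one left at the default group=1 is not.
def _example_group_conv_check() -> bool:
    grouped = onnx.helper.make_node("Conv", ["X", "W"], ["Y"], group=2)
    plain = onnx.helper.make_node("Conv", ["X", "W"], ["Y"])
    return _is_group_conv(grouped) and not _is_group_conv(plain)
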
def _make_gemm_layer_info(
    node: NodeProto,
    graph: ONNXGraph,
    execution_order: int,
) -> Optional[LayerInfo]:
    param = ModelInfo._get_node_param_array(node, graph)
    if param is None:
        return None

    attributes = get_node_attributes(node)
    if attributes.get("transB", 0) != 0:
        # ensure that param shape is (in_channels, out_channels)
        param = param.transpose()

    return LayerInfo.linear_layer(
        name=node.input[1],
        in_channels=param.shape[0],
        out_channels=param.shape[-1],
        bias=len(node.input) > 2,
        execution_order=execution_order,
        attributes=dict(sparsity=_param_sparsity(param)),
    )

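# Shape note (illustrative): with transB=1, Gemm stores its weight as
# (out_channels, in_channels), e.g. a hypothetical (10, 512) array; the
# transpose above yields (512, 10), so param.shape[0] reads in_channels=512
# and param.shape[-1] reads out_channels=10.
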
def _convert_quantizable_gemm(
    model: ModelProto,
    gemm_node: NodeProto,
    input_quantize_node: NodeProto,
    weight_dequantize_node: NodeProto,
    weight_quantize_node: NodeProto,
    output_quantize_node: NodeProto,
):
    # Gemm -> (QLinearMatMul -> Add(bias))
    weight_quantize_params = get_quantization_params(
        model, weight_quantize_node, include_target=True
    )
    if weight_quantize_params.target is None:
        # weight initializer not included
        return

    gemm_attributes = get_node_attributes(gemm_node)
    if any(float(attribute) != 1.0 for attribute in gemm_attributes.values()):
        # can only handle Gemm operations without alpha/beta/transB set
        return

    # can fold the input/output quant ops if they are trivial
    fold_input_quant = input_quantize_node.op_type == "DequantizeLinear"
    fold_output_quant = output_quantize_node.op_type == "QuantizeLinear"

    # quantize weight
    quantized_weight = _quantize_array(
        weight_quantize_params.target,
        weight_quantize_params.scale,
        weight_quantize_params.zero_point,
    )
    quantized_weight = quantized_weight.transpose()  # Gemm has implicit transpose
    quantized_weight_name = "{}.weight_quantized".format(gemm_node.name)
    quantized_weight_initializer = numpy_helper.from_array(
        quantized_weight, name=quantized_weight_name
    )
    model.graph.initializer.append(quantized_weight_initializer)

    # get qmatmul inputs and outputs
    qmatmul_input = (
        input_quantize_node.input[0] if fold_input_quant else gemm_node.input[0]
    )
    qmatmul_inputs = [
        qmatmul_input,  # x
        input_quantize_node.input[1],  # x_scale
        input_quantize_node.input[2],  # x_zero_point
        quantized_weight_name,  # w
        weight_quantize_node.input[1],  # w_scale
        weight_quantize_node.input[2],  # w_zero_point
        output_quantize_node.input[1],  # y_scale
        output_quantize_node.input[2],  # y_zero_point
    ]
    qmatmul_output = (
        output_quantize_node.output[0] if fold_output_quant else gemm_node.output[0]
    )
    qmatmul_name = "{}_quant".format(gemm_node.name)

    # create qmatmul node and add it to graph
    qmatmul_node = onnx.helper.make_node(
        "QLinearMatMul",
        qmatmul_inputs,
        [qmatmul_output],
        qmatmul_name,
    )
    model.graph.node.append(qmatmul_node)

    # delete folded quantization ops
    delete_quant_node(model, weight_dequantize_node, keep_params=False)
    delete_quant_node(model, weight_quantize_node, keep_params=True)
    if fold_input_quant and len(get_node_output_nodes(model, input_quantize_node)) <= 1:
        # fold if this gemm is the only node that reads from this quant op
        delete_quant_node(model, input_quantize_node, keep_params=True)
    if fold_output_quant:
        delete_quant_node(model, output_quantize_node, keep_params=True)

    if len(gemm_node.input) > 2:
        # add bias term following FC in the graph
        qmatmul_child_node = get_node_output_nodes(model, qmatmul_node)
        assert qmatmul_child_node, "QLinearMatMul node must have an output in the graph"
        dequant_output_name = "{}_dequantized".format(qmatmul_name)

        if qmatmul_child_node[0].op_type == "DequantizeLinear":
            qmatmul_dequantize_node = qmatmul_child_node[0]
            # create hidden output layer for bias add
            add_output_name = qmatmul_dequantize_node.output[0]
            swap_node_output(qmatmul_dequantize_node, dequant_output_name)
        else:
            # inject dequantize op for matmul
            qmatmul_output_name = "{}_output".format(qmatmul_name)
            swap_node_output(qmatmul_node, qmatmul_output_name)
            qmatmul_dequantize_node = onnx.helper.make_node(
                "DequantizeLinear",
                [
                    qmatmul_output_name,  # input
                    output_quantize_node.input[1],  # scale
                    output_quantize_node.input[2],  # zero point
                ],
                [dequant_output_name],
                "{}_dequantize".format(qmatmul_name),
            )
            model.graph.node.append(qmatmul_dequantize_node)
            add_output_name = qmatmul_output  # original qmatmul output name

        # inject bias op for dequantized matmul output
        qmatmul_bias_add_node = onnx.helper.make_node(
            "Add",
            [
                qmatmul_dequantize_node.output[0],  # add input
                gemm_node.input[2],  # Gemm bias
            ],
            [add_output_name],
            "{}_bias_add".format(gemm_node.name),
        )
        model.graph.node.append(qmatmul_bias_add_node)

    # delete original Gemm node, keeping the bias initializer if one exists
    params_to_keep = [gemm_node.input[2]] if len(gemm_node.input) > 2 else []
    remove_node_and_params_from_graph(model, gemm_node, keep_params=params_to_keep)

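# For reference, a minimal sketch of the affine quantization step that
# _quantize_array is expected to perform on the weight above (the real helper is
# defined elsewhere in this module; this sketch assumes unsigned 8-bit
# quantization with a scalar scale and zero point, and the name
# _quantize_array_sketch is hypothetical):
def _quantize_array_sketch(array, scale, zero_point):
    import numpy

    quantized = numpy.round(array / scale) + zero_point
    return numpy.clip(quantized, 0, 255).astype(numpy.uint8)
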
def __init__(
    self,
    model: Union[ModelProto, None],
    node: Union[Any, None],
    node_shape: Union[NodeShape, None] = None,
    **kwargs,
):
    if model is None and node is None:
        # construct directly from pre-computed values passed via kwargs
        self._id = kwargs["id"]
        self._op_type = kwargs["op_type"]
        self._input_names = kwargs["input_names"]
        self._output_names = kwargs["output_names"]
        self._input_shapes = kwargs["input_shapes"]
        self._output_shapes = kwargs["output_shapes"]
        self._params = kwargs["params"]
        self._prunable = kwargs["prunable"]
        self._prunable_params_zeroed = kwargs["prunable_params_zeroed"]
        self._weight_name = kwargs["weight_name"]
        self._weight_shape = kwargs["weight_shape"]
        self._bias_name = kwargs["bias_name"]
        self._bias_shape = kwargs["bias_shape"]
        self._attributes = kwargs["attributes"]
        self._flops = kwargs["flops"]
        self._prunable_equation_sensitivity = (
            kwargs["prunable_equation_sensitivity"]
            if "prunable_equation_sensitivity" in kwargs
            else None
        )
        return

    if model is None or node is None:
        raise ValueError("both model and node must be provided (neither can be None)")

    self._id = extract_node_id(node)
    self._op_type = node.op_type
    self._input_names = get_node_inputs(model, node)
    self._output_names = get_node_outputs(model, node)

    if node_shape is None:
        self._input_shapes = None
        self._output_shapes = None
    else:
        self._input_shapes = node_shape.input_shapes
        self._output_shapes = node_shape.output_shapes

    self._params = 0
    self._prunable = is_prunable_node(model, node)
    self._prunable_params = 0
    self._prunable_params_zeroed = 0
    self._weight_name = None
    self._weight_shape = None
    self._bias_name = None
    self._bias_shape = None
    self._attributes = get_node_attributes(node)

    if self._prunable:
        weight, bias = get_node_params(model, node)
        self._params += weight.val.size
        self._prunable_params += weight.val.size
        self._prunable_params_zeroed += weight.val.size - numpy.count_nonzero(
            weight.val
        )
        self._weight_name = weight.name
        self._weight_shape = [s for s in weight.val.shape]

        if bias is not None:
            self._bias_name = bias.name
            self._params += bias.val.size
            self._bias_shape = [s for s in bias.val.shape]

    kernel_shape = get_kernel_shape(self._attributes)
    self._flops = calculate_flops(
        self._op_type,
        input_shape=self._input_shapes,
        output_shape=self._output_shapes,
        weight_shape=self._weight_shape,
        kernel_shape=kernel_shape,
        bias_shape=self._bias_shape,
        attributes=self._attributes,
    )

    self._prunable_equation_sensitivity = (
        pruning_loss_sens_approx(
            self._input_shapes,
            self._output_shapes,
            self._params,
            apply_shape_change_mult=True,
        )
        if self._prunable
        else None
    )

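# Arithmetic example (illustrative): for a prunable node whose weight tensor has
# 1,000 entries of which 250 are non-zero, the code above records
# prunable_params_zeroed = 1000 - 250 = 750, i.e. the node is measured as 75% sparse.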