def forward(self, x: torch.Tensor) -> QuantTensor:
    if self.is_quant_enabled:
        out, scale, zero_point, bit_width = self.tensor_quant(x)
        return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
    else:  # quantization disabled
        return QuantTensor(x, training=self.training)

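# Hedged usage sketch (not from the source): one way a caller might unpack the
# QuantTensor returned by a proxy with the forward above. `proxy` is a
# hypothetical stand-in for any module exposing that signature.
import torch

def int_repr_from_proxy(proxy, x: torch.Tensor) -> torch.Tensor:
    qt = proxy(x)  # the forward above returns a QuantTensor
    # invert the affine mapping: value = (int - zero_point) * scale
    return torch.round(qt.value / qt.scale + qt.zero_point)
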
def forward(self, x: torch.Tensor) -> QuantTensor:
    if self.is_quant_enabled:
        impl = self.export_handler if self.export_mode else self.tensor_quant
        out, pre_scale, pre_zero_point, scale, zero_point, bit_width = impl(x)
        return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
    else:  # quantization disabled
        return QuantTensor(x, training=self.training)

def forward(
        self,
        x: Tensor,
        input_scale: Optional[Tensor] = None,
        input_bit_width: Optional[Tensor] = None) -> QuantTensor:
    if self.is_quant_enabled:
        if self.requires_input_scale and input_scale is None:
            raise RuntimeError("Input scale required")
        if self.requires_input_bit_width and input_bit_width is None:
            raise RuntimeError("Input bit-width required")
        if self.requires_input_scale and self.requires_input_bit_width:
            input_scale = input_scale.view(-1)
            out, out_scale, out_zp, out_bit_width = self.tensor_quant(
                x, input_scale, input_bit_width)
        elif self.requires_input_scale and not self.requires_input_bit_width:
            input_scale = input_scale.view(-1)
            out, out_scale, out_zp, out_bit_width = self.tensor_quant(x, input_scale)
        elif not self.requires_input_scale and not self.requires_input_bit_width:
            out, out_scale, out_zp, out_bit_width = self.tensor_quant(x)
        else:
            raise RuntimeError("Internally defined bit-width required")
        # zero point precedes bit width both in the tensor_quant output tuple
        # and in the QuantTensor constructor, consistent with the other proxies
        return QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
    else:
        return QuantTensor(x, training=self.training)

def forward(self, x: QuantTensor):
    if self.is_quant_enabled:
        cleaned_up_value = round_ste(x.value / x.scale.detach()) * x.scale.detach()
        x = x.set(value=cleaned_up_value)  # clean up accumulated floating point errors
        trunc_bit_width = self.lsb_trunc_bit_width_impl(x.bit_width)
        trunc_scale = 2.0 ** trunc_bit_width
        output_scale = trunc_scale * x.scale
        if self.training:
            x, output_scale, x_bit_width = self.tensor_quant(x.value, output_scale, x.bit_width)
        else:  # avoid fp errors at inference time
            x_bit_width = x.bit_width
            x = round_ste(x.value / x.scale)
            x = x / trunc_scale
            x = self.tensor_quant.int_quant.float_to_int_impl(x)
            x = x * output_scale
            x = x / trunc_scale
            output_scale = output_scale / trunc_scale  # output_scale == input_scale
        output_bit_width = x_bit_width - trunc_bit_width
        return QuantTensor(x, output_scale, output_bit_width, self.is_signed)
    else:
        return x

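# Worked sketch of the truncation arithmetic above (assumptions: a floor-style
# float_to_int_impl and integer-valued inputs). With a 16-bit accumulator and
# 8 truncated LSBs, trunc_scale = 2.0 ** 8 and output_bit_width = 16 - 8 = 8.
import torch

def trunc_lsbs(int_value: torch.Tensor, scale: torch.Tensor,
               bit_width: int, trunc_bit_width: int):
    trunc_scale = 2.0 ** trunc_bit_width
    out_int = torch.floor(int_value / trunc_scale)  # drop the low bits
    # mirrors the inference branch above, where the returned scale equals the input scale
    return out_int * scale, bit_width - trunc_bit_width
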
def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
    if self.is_act_enabled or self.is_quant_enabled:
        if isinstance(x, QuantTensor):
            x = x.value
        x = self.fused_activation_quant_proxy(x)
        return QuantTensor(*x, signed=self.is_signed)
    else:
        if isinstance(x, QuantTensor):  # passthrough
            return x
        else:
            return QuantTensor(x)

def pack_output(self, quant_output: QuantTensor):
    if not self.training and self.cache_inference_quant_out:
        self._cached_out = _CachedIO(quant_output.detach(), self.cache_quant_io_metadata_only)
    if self.return_quant_tensor:
        return quant_output
    else:
        return quant_output.value

def forward(self, x: QuantTensor):
    if self.is_quant_enabled:
        out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
        out_value, out_scale, out_zp, out_bit_width = out_tuple
        return QuantTensor(out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
    return x

def forward(self, x: QuantTensor):
    if self.is_quant_enabled:
        out_value, out_scale, out_bit_width = self.tensor_quant(x.value, x.scale, x.bit_width)
        return QuantTensor(out_value, out_scale, out_bit_width, x.signed)
    else:
        return x

def forward(self, inp):
    output_scale = None
    output_zp = None
    output_bit_width = None
    inp = self.unpack_input(inp)
    norm = inp.value.norm(p='fro', keepdim=True) + self.eps
    out = inp.value / norm
    out = nn.functional.linear(out, self.proj[:self.out_channels, :self.in_channels])
    out = -self.scale * out
    if inp.scale is not None:
        output_scale = inp.scale * self.scale / norm
    if inp.bit_width is not None:
        output_bit_width = self.max_output_bit_width(inp.bit_width)
    if (self.return_quant_tensor and inp.zero_point is not None
            and (inp.zero_point != 0.0).any()):
        raise RuntimeError("Computing zero point of output accumulator not supported yet.")
    else:
        output_zp = inp.zero_point
    out = QuantTensor(
        value=out,
        scale=output_scale,
        zero_point=output_zp,
        bit_width=output_bit_width,
        signed=True,
        training=self.training)
    return out

def pack_output(self, output, output_scale, output_bit_width):
    if self.return_quant_tensor:
        return QuantTensor(tensor=output, scale=output_scale, bit_width=output_bit_width)
    else:
        return output

def test_brevitas_fc_onnx_export_and_exec(size, wbits, abits, pretrained):
    if size == "LFC" and wbits == 2 and abits == 2:
        pytest.skip(f"No LFC_{MAX_WBITS}W{MAX_ABITS}A present.")
    if wbits > abits:
        pytest.skip("No wbits > abits cases.")
    nname = f"{size}_{wbits}W{abits}A"
    finn_onnx = nname + ".onnx"
    fc, _ = model_with_cfg(nname.lower(), pretrained=pretrained)
    fc.eval()
    # load a random int test vector
    input_a = np.random.randint(MIN_INP_VAL, MAX_INP_VAL, size=FC_INPUT_SIZE).astype(np.float32)
    scale = 1. / 255
    input_t = torch.from_numpy(input_a * scale)
    input_qt = QuantTensor(
        input_t, scale=torch.tensor(scale), bit_width=torch.tensor(8.0), signed=False)
    FINNManager.export(fc, export_path=finn_onnx, input_t=input_qt)
    model = ModelWrapper(finn_onnx)
    model = model.transform(GiveUniqueNodeNames())
    model = model.transform(DoubleToSingleFloat())
    model = model.transform(InferShapes())
    model = model.transform(FoldConstants())
    model = model.transform(RemoveStaticGraphInputs())
    # run using FINN-based execution
    input_dict = {"0": input_a}
    output_dict = oxe.execute_onnx(model, input_dict)
    produced = output_dict[list(output_dict.keys())[0]]
    # do forward pass in PyTorch/Brevitas
    expected = fc.forward(input_t).detach().numpy()
    assert np.isclose(produced, expected, atol=ATOL).all()

def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
    if self.is_act_enabled or self.is_quant_enabled:
        y = x
        if isinstance(y, QuantTensor):
            y = y.value
        y = self.fused_activation_quant_proxy(y)
        if isinstance(y, tuple):
            return QuantTensor(*y, signed=self.is_signed)
        elif self.passthrough_act:  # preserve scale/bit/sign even without output quant
            return QuantTensor(y, x.scale, x.bit_width, x.signed)
        else:
            return QuantTensor(y)
    else:
        if isinstance(x, QuantTensor):  # passthrough
            return x
        else:
            return QuantTensor(x)

def test_forward_bias_int(self):
    mod = QuantLinear(
        out_features=OUTPUT_FEATURES,
        in_features=INPUT_FEATURES,
        bias=True,
        bias_quant_type='INT')
    x = QuantTensor(torch.rand(size=(3, INPUT_FEATURES)), torch.tensor(1.0), torch.tensor(3))
    assert mod(x) is not None

def forward(self, x: Union[Tensor, QuantTensor]) -> QuantTensor:
    if self.fused_activation_quant_proxy is not None:
        y = x
        if isinstance(y, QuantTensor):
            y = y.value
        y = self.fused_activation_quant_proxy(y)
        if isinstance(y, tuple):
            return QuantTensor(*y, signed=self.is_signed, training=self.training)
        elif self.is_passthrough_act:  # preserve scale/zp/bit/sign even without output quant
            return QuantTensor(y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
        else:
            return QuantTensor(y, training=self.training)
    else:
        if isinstance(x, QuantTensor):  # passthrough
            return x
        else:
            return QuantTensor(x, training=self.training)

def forward(
        self,
        x: Tensor,
        input_scale: Tensor,
        input_bit_width: Optional[Tensor]) -> QuantTensor:
    if self.is_quant_enabled:
        if input_scale is None:
            raise RuntimeError("Input scale can't be None when quantizing bias")
        input_scale = input_scale.view(-1)
        if self.requires_input_bit_width:  # bit width is defined outside
            if input_bit_width is None:
                raise RuntimeError("Input or predefined bit width required")
            out, out_scale, out_bit_width = self.tensor_quant(x, input_scale, input_bit_width)
        else:
            out, out_scale, out_bit_width = self.tensor_quant(x, input_scale)
        return QuantTensor(out, out_scale, out_bit_width, self.is_signed)
    else:
        return QuantTensor(x)

def forward(self, x: QuantTensor):
    if self.is_quant_enabled:
        impl = self.export_handler if self.export_mode else self.tensor_quant
        out_tuple = impl(x.value, x.scale, x.zero_point, x.bit_width)
        out_value, out_scale, out_zp, out_bit_width = out_tuple
        return QuantTensor(out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
    else:
        return x

def forward(
        self,
        tensor_list: Union[List[Tensor], List[QuantTensor]],
        dim: int = 1) -> Union[Tensor, QuantTensor]:
    quant_tensor_list = [self.unpack_input(t) for t in tensor_list]
    # shortcut execution through the export impl during export
    if self.export_mode:
        return self.export_handler([qt.value for qt in quant_tensor_list])
    quant_tensor_list = [self.input_quant(qt) for qt in quant_tensor_list]
    # trigger an assert if scale factors and bit widths are None or different
    output = QuantTensor.cat(quant_tensor_list, dim=dim)
    quant_output = self.output_quant(output)
    return self.pack_output(quant_output)

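# Usage sketch for the quant concatenation above (assumes the Brevitas package
# layout; the scale and shapes are illustrative). QuantTensor.cat checks that
# the inputs share quantization metadata, so both operands are built with the
# same scale, zero point and bit width:
import torch
from brevitas.quant_tensor import QuantTensor

s, zp, bw = torch.tensor(0.1), torch.tensor(0.), torch.tensor(8.)
a_val = torch.round(torch.randn(1, 4) / s) * s  # values on the quant grid
b_val = torch.round(torch.randn(1, 4) / s) * s
a = QuantTensor(a_val, s, zp, bw, signed=True, training=False)
b = QuantTensor(b_val, s, zp, bw, signed=True, training=False)
cat_ab = QuantTensor.cat([a, b], dim=1)  # keeps the shared scale/bit width
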
def test_export():
    x = QuantTensor(
        torch.randn(IN_SIZE),
        scale=torch.tensor(2.0 ** (-7)),
        bit_width=torch.tensor(8),
        signed=True)
    mod = QuantModel()
    # Export quantized model to ONNX
    export_dpuv1_onnx(
        mod,
        input_shape=IN_SIZE,
        input_t=x,
        export_path='quant_model.onnx',
        input_names=["input_%d" % i for i in range(5)],
        output_names=["output"])

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
    output_scale = None
    output_bit_width = None
    inp = self.unpack_input(inp)
    # shortcut execution through the export impl during export
    if self.export_mode:
        return self.export_handler(inp.value)
    quant_input = self.input_quant(inp)
    quant_weight = self.quant_weight()
    if quant_input.bit_width is not None:
        output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
    if quant_input.scale is not None:
        output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
        output_scale = quant_weight.scale.view(output_scale_shape)
        output_scale = output_scale * quant_input.scale.view(output_scale_shape)
    if self.bias is not None:
        quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
        if not self.training and self.cache_inference_quant_bias:
            self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
        output_tensor = self.inner_forward_impl(
            quant_input.value, quant_weight.value, quant_bias.value)
        if quant_bias.bit_width is not None and output_bit_width is not None:
            output_bit_width = torch.where(
                quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
            output_bit_width = output_bit_width + 1
    else:
        output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None)
    quant_output = QuantTensor(output_tensor, output_scale, output_bit_width, signed=True)
    quant_output = self.output_quant(quant_output)
    return self.pack_output(quant_output)

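# Illustrative sketch of the accumulator bookkeeping above. This bound is an
# assumption for exposition, not the actual max_acc_bit_width implementation:
# the accumulator must hold num_accumulations products of signed operands.
import math

def acc_bit_width_bound(input_bit_width: int, weight_bit_width: int,
                        num_accumulations: int) -> int:
    max_in = 2 ** (input_bit_width - 1) - 1   # largest signed input magnitude
    max_w = 2 ** (weight_bit_width - 1) - 1   # largest signed weight magnitude
    # worst case: |acc| <= num_accumulations * max_in * max_w, plus a sign bit
    return math.ceil(math.log2(num_accumulations * max_in * max_w)) + 1
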
def unpack_input(self, inp: Union[Tensor, QuantTensor]):
    if isinstance(inp, QuantTensor):
        if self.export_mode:
            raise RuntimeError("QuantTensor I/O can't be used during export.")
        if not self.training and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp
    else:
        inp = QuantTensor(inp)
        if not self.training and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp

def unpack_input(self, inp: Union[Tensor, QuantTensor]):
    self._set_global_is_quant_layer(True)
    # Hack to recognize a QuantTensor that has decayed to a tuple
    # when used as input to tracing (e.g. during ONNX export)
    if (torch._C._get_tracing_state() is not None
            and isinstance(inp, tuple)
            and len(inp) == len(QuantTensor._fields)
            and all([isinstance(t, Tensor) for t in inp])):
        inp = QuantTensor(*inp)
    if isinstance(inp, QuantTensor):
        # don't cache values during the export pass
        if not self.training and not self._export_mode and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp
    else:
        inp = QuantTensor(inp, training=self.training)
        if not self.training and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp

def unpack_input(self, inp: Union[Tensor, QuantTensor]):
    if isinstance(inp, QuantTensor):
        # don't cache values during the export pass
        if not self.training and not self._export_mode and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp
    else:
        inp = QuantTensor(inp)
        if not self.training and self.cache_inference_quant_inp:
            cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
            self._cached_inp = cached_inp
        return inp

def pack_output(self, output, output_scale, output_bit_width):
    if self._export_mode:
        # do not ever return QuantTensor while exporting;
        # cached scale factors will be used in the next layer
        return output
    else:
        self.export_out_shape = output.shape  # TODO control caching with own config variable
        self.export_out_scale = output_scale
        self.export_out_bit_width = output_bit_width
        if self.return_quant_tensor:
            return QuantTensor(tensor=output, scale=output_scale, bit_width=output_bit_width)
        else:
            return output

def test_brevitas_avg_pool_export(
        kernel_size, stride, signed, bit_width, input_bit_width, channels, idim):
    quant_avgpool = QuantAvgPool2d(kernel_size=kernel_size, stride=stride, bit_width=bit_width)
    quant_avgpool.eval()
    # determine input
    prefix = 'INT' if signed else 'UINT'
    dt_name = prefix + str(input_bit_width)
    dtype = DataType[dt_name]
    input_shape = (1, channels, idim, idim)
    input_array = gen_finn_dt_tensor(dtype, input_shape)
    scale_array = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
    input_tensor = torch.from_numpy(input_array * scale_array).float()
    scale_tensor = torch.from_numpy(scale_array).float()
    zp = torch.tensor(0.)
    input_quant_tensor = QuantTensor(
        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False)
    # export
    FINNManager.export(quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    # reference brevitas output
    ref_output_array = quant_avgpool(input_quant_tensor).tensor.detach().numpy()
    # finn output
    idict = {model.graph.input[0].name: input_array}
    odict = oxe.execute_onnx(model, idict, True)
    finn_output = odict[model.graph.output[0].name]
    # compare outputs
    assert np.isclose(ref_output_array, finn_output).all()
    # cleanup
    os.remove(export_onnx_path)

def __init__(self, quant_tensor: QuantTensor, metadata_only: bool):
    self.shape = quant_tensor.value.shape
    if metadata_only:
        self.quant_tensor = quant_tensor.set(value=None)
    else:
        self.quant_tensor = quant_tensor

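# Minimal usage sketch (assumes the Brevitas package layout and that
# QuantTensor.set returns a copy with the given field replaced, as the proxies
# above rely on): metadata-only caching keeps scale/zero-point/bit-width for
# export while dropping the cached values themselves.
import torch
from brevitas.quant_tensor import QuantTensor

qt = QuantTensor(
    torch.zeros(1, 3), torch.tensor(0.1), torch.tensor(0.), torch.tensor(8.),
    signed=True, training=False)
cached = _CachedIO(qt, metadata_only=True)  # _CachedIO from the snippet above
assert cached.quant_tensor.value is None
assert cached.shape == (1, 3)
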
def forward(self, x: torch.Tensor) -> QuantTensor:
    if self.is_quant_enabled:
        out, scale, bit_width = self.tensor_quant(x)
        return QuantTensor(out, scale, bit_width, signed=self.is_signed)
    else:  # quantization disabled
        return QuantTensor(x)

def test_brevitas_avg_pool_export(
        kernel_size, stride, signed, bit_width, input_bit_width, channels, idim):
    ishape = (1, channels, idim, idim)
    ibw_tensor = torch.Tensor([input_bit_width])
    b_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size,
        stride=stride,
        signed=signed,
        min_overall_bit_width=bit_width,
        max_overall_bit_width=bit_width,
        quant_type=QuantType.INT)
    # call forward pass manually once to cache scale factor and bitwidth
    input_tensor = torch.from_numpy(np.zeros(ishape)).float()
    scale = np.ones((1, channels, 1, 1))
    output_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor, signed)
    FINNManager.export_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    # determine input FINN datatype
    prefix = "INT" if signed else "UINT"
    dt_name = prefix + str(input_bit_width // 2)
    dtype = DataType[dt_name]
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    # execution with input tensor using integers and scale = 1
    # calculate golden output
    inp = gen_finn_dt_tensor(dtype, ishape)
    input_tensor = torch.from_numpy(inp).float()
    input_quant_tensor = QuantTensor(input_tensor, output_scale, ibw_tensor, signed)
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # finn execution
    idict = {model.graph.input[0].name: inp}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert (expected == produced).all()
    # execution with input tensor using float and scale != 1
    scale = np.random.uniform(low=0, high=1, size=(1, channels, 1, 1)).astype(np.float32)
    inp_tensor = inp * scale
    input_tensor = torch.from_numpy(inp_tensor).float()
    input_scale = torch.from_numpy(scale).float()
    input_quant_tensor = QuantTensor(input_tensor, input_scale, ibw_tensor, signed)
    # export again to set the scale values correctly
    bo.export_finn_onnx(b_avgpool, ishape, export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    b_avgpool.eval()
    expected = b_avgpool.forward(input_quant_tensor).tensor.detach().numpy()
    # finn execution
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    assert np.isclose(expected, produced).all()
    os.remove(export_onnx_path)

def bit_width(self):
    zhs = self._zero_hw_sentinel()
    empty_imp = QuantTensor(zhs, zhs, zhs, zhs)
    bit_width = self.__call__(empty_imp).bit_width
    return bit_width

def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
    output_scale = None
    output_bit_width = None
    output_zero_point = None
    output_signed = None
    inp = self.unpack_input(inp)
    # shortcut execution through the export impl during export
    if self.export_mode:
        out = self.export_handler(inp.value)
        self._set_global_is_quant_layer(False)
        return out
    quant_input = self.input_quant(inp)
    quant_weight = self.quant_weight()
    if quant_input.bit_width is not None:
        output_bit_width = self.max_acc_bit_width(quant_input.bit_width, quant_weight.bit_width)
    if quant_input.scale is not None:
        output_scale_shape = compute_channel_view_shape(inp, channel_dim=1)
        output_scale = quant_weight.scale.view(output_scale_shape)
        output_scale = output_scale * quant_input.scale.view(output_scale_shape)
    if quant_input.signed is not None:
        output_signed = inp.signed or quant_weight.signed
    if self.bias is not None:
        quant_bias = self.bias_quant(self.bias, output_scale, output_bit_width)
        if not self.training and self.cache_inference_quant_bias:
            self._cached_bias = _CachedIO(quant_bias.detach(), metadata_only=False)
        output_tensor = self.inner_forward_impl(
            quant_input.value, quant_weight.value, quant_bias.value)
        if (output_scale is not None
                and (quant_bias.scale is None
                     or quant_bias.scale.data_ptr() != output_scale.data_ptr())):
            output_zero_point = -quant_bias.value.view(output_scale_shape) / output_scale
        if quant_bias.bit_width is not None and output_bit_width is not None:
            output_bit_width = torch.where(
                quant_bias.bit_width > output_bit_width, quant_bias.bit_width, output_bit_width)
            output_bit_width = output_bit_width + 1
    else:
        output_tensor = self.inner_forward_impl(quant_input.value, quant_weight.value, None)
    if self.return_quant_tensor and not self.is_output_quant_enabled:
        if (quant_input.zero_point is not None
                and ((quant_input.zero_point != 0.0).any()
                     or (quant_weight.zero_point != 0.0).any())):
            raise RuntimeError("Computing zero point of output accumulator not supported yet.")
        elif quant_input.zero_point is not None and output_zero_point is None:
            output_zero_point = quant_input.zero_point
    quant_output = QuantTensor(
        value=output_tensor,
        scale=output_scale,
        zero_point=output_zero_point,
        bit_width=output_bit_width,
        signed=output_signed,
        training=self.training)
    quant_output = self.output_quant(quant_output)
    return self.pack_output(quant_output)

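# Why the bias folds into a zero point above (a sketch of the algebra, not
# source commentary): if the accumulator dequantizes as y = s_out * acc_int + b
# and b is quantized at a scale other than s_out, the same y can be written in
# zero-point form y = s_out * (acc_int - zp) with zp = -b / s_out, which is
# exactly the output_zero_point assigned in the bias branch.
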
def test_brevitas_avg_pool_export(
        kernel_size, stride, signed, bit_width, input_bit_width, channels, idim, QONNX_export):
    export_onnx_path = base_export_onnx_path.replace(".onnx", f"test_QONNX-{QONNX_export}.onnx")
    quant_avgpool = QuantAvgPool2d(
        kernel_size=kernel_size, stride=stride, bit_width=bit_width, return_quant_tensor=False)
    quant_avgpool.eval()
    # determine input
    prefix = "INT" if signed else "UINT"
    dt_name = prefix + str(input_bit_width)
    dtype = DataType[dt_name]
    input_shape = (1, channels, idim, idim)
    input_array = gen_finn_dt_tensor(dtype, input_shape)
    # Brevitas QuantAvgPool layers need QuantTensors to export correctly,
    # which requires setting up a QuantTensor instance with the scale
    # factor, zero point, bit width and signedness
    scale_array = np.ones((1, channels, 1, 1)).astype(np.float32)
    scale_array *= 0.5
    input_tensor = torch.from_numpy(input_array * scale_array).float()
    scale_tensor = torch.from_numpy(scale_array).float()
    zp = torch.tensor(0.0)
    input_quant_tensor = QuantTensor(
        input_tensor, scale_tensor, zp, input_bit_width, signed, training=False)
    # export
    if QONNX_export:
        BrevitasONNXManager.export(
            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
        model = ModelWrapper(export_onnx_path)
        # Statically set the additional inputs generated by the BrevitasONNXManager
        model.graph.input.remove(model.graph.input[3])
        model.graph.input.remove(model.graph.input[2])
        model.graph.input.remove(model.graph.input[1])
        model.set_initializer("1", scale_array)
        model.set_initializer("2", np.array(0.0).astype(np.float32))
        model.set_initializer("3", np.array(input_bit_width).astype(np.float32))
        model.save(export_onnx_path)
        qonnx_cleanup(export_onnx_path, out_file=export_onnx_path)
        model = ModelWrapper(export_onnx_path)
        model = model.transform(ConvertQONNXtoFINN())
        model.save(export_onnx_path)
    else:
        FINNManager.export(
            quant_avgpool, export_path=export_onnx_path, input_t=input_quant_tensor)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    model = model.transform(InferDataTypes())
    # reference brevitas output
    ref_output_array = quant_avgpool(input_quant_tensor).detach().numpy()
    # finn output
    if QONNX_export:
        # Manually apply the Quant tensor scaling for QONNX
        idict = {model.graph.input[0].name: input_array * scale_array}
    else:
        idict = {model.graph.input[0].name: input_array}
    odict = oxe.execute_onnx(model, idict, True)
    finn_output = odict[model.graph.output[0].name]
    # compare outputs
    assert np.isclose(ref_output_array, finn_output).all()
    # cleanup
    os.remove(export_onnx_path)