# Import paths below assume the (older) brevitas / FINN versions this code
# was written against; newer releases moved several of these modules.
import os

import numpy as np
import torch

import brevitas.onnx as bo
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.nn import QuantHardTanh

import finn.core.onnx_exec as oxe
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.infer_shapes import InferShapes

# assumed scratch path for the exported model (the original constant is
# defined elsewhere in the test module)
export_onnx_path = "test_brevitas_act_export.onnx"


def thresholds(module: QuantHardTanh, extend_tensor_to_channels=True):
    bit_width = int(module.quant_act_bit_width().item())
    if bit_width != 1:
        if module.is_quant_act_narrow_range:
            # assuming narrow range, symmetric quantization around zero
            # when using narrow range, we represent one element less
            num_distinct_values = 2 ** bit_width - 1
        else:
            num_distinct_values = 2 ** bit_width
        num_thresholds = num_distinct_values - 1
        flat_scale = module.quant_act_scale().view(-1)
        num_scale_channels = flat_scale.shape[0]
        step = torch.abs(flat_scale)
        half_step = step / 2.0
        thresholds = torch.empty(num_scale_channels, num_thresholds)
        # compute the value of the smallest threshold; we'll neg-bias all
        # generated thresholds by this much
        min_threshold = -half_step - step * ((num_thresholds // 2) - 1)
        if not module.is_quant_act_narrow_range:
            min_threshold -= step
        for c in range(num_scale_channels):
            for t in range(num_thresholds):
                thresholds[c][t] = min_threshold[c] + step[c] * t
        if extend_tensor_to_channels:
            output_channels = module._cached_inp.shape[1]
            final_shape = (output_channels, num_thresholds)
            if thresholds.shape != final_shape:
                thresholds = thresholds.expand(final_shape)
        return thresholds
    else:
        thresholds = torch.empty([1, 1])
        thresholds[0] = 0
        return thresholds
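# A minimal standalone sketch of the threshold arithmetic above, assuming a
# single-channel scale of 0.5 and bit_width=2 with narrow_range=True; the
# helper name and the hand-computed expected values are illustrative, not
# part of the export handler.
def _demo_thresholds_2bit_narrow():
    s = torch.tensor([0.5])  # per-channel scale
    num_thresholds = (2 ** 2 - 1) - 1  # 3 narrow-range levels -> 2 thresholds
    min_threshold = -s / 2.0 - s * ((num_thresholds // 2) - 1)  # -> -0.25
    demo = torch.stack([min_threshold + s * t for t in range(num_thresholds)], dim=1)
    # the resulting multithreshold maps x < -0.25 -> level 0,
    # -0.25 <= x < 0.25 -> level 1, and x >= 0.25 -> level 2
    assert torch.allclose(demo, torch.tensor([[-0.25, 0.25]]))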
def quant_act_scale(module: QuantHardTanh):
    bit_width = int(module.quant_act_bit_width().item())
    quant_act_scale = module.quant_act_scale().type(torch.FloatTensor).detach()
    if bit_width != 1:
        return quant_act_scale
    else:
        assert quant_act_scale.view(-1).shape[0] == 1, "Unsupported BIPOLAR per channel scale"
        assert quant_act_scale.flatten().item() == 1.0, "Unsupported BIPOLAR scale != 1"
        return quant_act_scale * 2
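# Hedged illustration of the BIPOLAR special case above: a 1-bit quantizer
# with unit scale emits {-1, +1}, so the spacing between its two output
# levels is scale * 2, which is why the handler doubles the scale. The
# helper below is hypothetical and only demonstrates the arithmetic.
def _demo_bipolar_scale():
    unit_scale = torch.tensor([1.0])
    levels = torch.tensor([-1.0, 1.0]) * unit_scale  # bipolar output levels
    spacing = levels[1] - levels[0]  # 2.0 == unit_scale * 2
    assert spacing.item() == 2.0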
def test_brevitas_act_export_qhardtanh_scaled(
    abits, narrow_range, min_val, max_val, scaling_impl_type
):
    def get_quant_type(bit_width):
        if bit_width is None:
            return QuantType.FP
        elif bit_width == 1:
            return QuantType.BINARY
        else:
            return QuantType.INT

    act_quant_type = get_quant_type(abits)
    ishape = (1, 15)
    b_act = QuantHardTanh(
        bit_width=abits,
        quant_type=act_quant_type,
        max_val=max_val,
        min_val=min_val,
        restrict_scaling_type=RestrictValueType.LOG_FP,
        scaling_impl_type=scaling_impl_type,
        narrow_range=narrow_range,
    )
    if scaling_impl_type == ScalingImplType.PARAMETER:
        checkpoint = {
            "act_quant_proxy.fused_activation_quant_proxy."
            "tensor_quant.scaling_impl.learned_value": torch.tensor(0.49).type(
                torch.FloatTensor
            )
        }
        b_act.load_state_dict(checkpoint)
    bo.export_finn_onnx(b_act, ishape, export_onnx_path)
    model = ModelWrapper(export_onnx_path)
    model = model.transform(InferShapes())
    inp_tensor = np.random.uniform(low=min_val, high=max_val, size=ishape).astype(np.float32)
    idict = {model.graph.input[0].name: inp_tensor}
    odict = oxe.execute_onnx(model, idict, True)
    produced = odict[model.graph.output[0].name]
    inp_tensor = torch.from_numpy(inp_tensor).float()
    b_act.eval()
    expected = b_act.forward(inp_tensor).detach().numpy()
    if not np.isclose(produced, expected, atol=1e-3).all():
        print(
            "abits: ", abits,
            " | narrow_range: ", narrow_range,
            " | min_val: ", min_val,
            " | max_val: ", max_val,
        )
        print("layer scale: ", b_act.quant_act_scale().type(torch.FloatTensor).detach())
        print("export scale: ", b_act.export_act_scale)
        if abits < 5:
            print(
                "thres:", ", ".join(["{:8.4f}".format(x) for x in b_act.export_thres[0]])
            )
            print("input:", ", ".join(["{:8.4f}".format(x) for x in inp_tensor[0]]))
            print("prod :", ", ".join(["{:8.4f}".format(x) for x in produced[0]]))
            print("expec:", ", ".join(["{:8.4f}".format(x) for x in expected[0]]))
    assert np.isclose(produced, expected, atol=1e-3).all()
    os.remove(export_onnx_path)
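# A hedged sketch of how the test above is typically driven via pytest
# parametrization; the value grids here are illustrative assumptions, not
# the project's actual configuration.
import pytest


@pytest.mark.parametrize("abits", [2, 4, 8])
@pytest.mark.parametrize("narrow_range", [False, True])
@pytest.mark.parametrize("min_val", [-1.0, -0.5])
@pytest.mark.parametrize("max_val", [0.5, 1.0])
@pytest.mark.parametrize(
    "scaling_impl_type", [ScalingImplType.CONST, ScalingImplType.PARAMETER]
)
def test_qhardtanh_scaled_param_sketch(abits, narrow_range, min_val, max_val, scaling_impl_type):
    test_brevitas_act_export_qhardtanh_scaled(
        abits, narrow_range, min_val, max_val, scaling_impl_type
    )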