def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale, float_to_int_impl, scale_multiplier):
    """Compare IntQuant's output against a reference clamp-and-rescale computation.

    The rounding implementation is wrapped in a Mock so the call can be
    intercepted, while TensorClamp provides the clamping behavior.
    """
    rounding_mock = Mock()
    rounding_mock.side_effect = float_to_int_impl()
    clamp_impl = TensorClamp()

    inp = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)
    atol = scale * scale_multiplier

    quant = IntQuant(
        narrow_range=narrow_range,
        signed=signed,
        float_to_int_impl=rounding_mock,
        tensor_clamp_impl=clamp_impl)
    output = quant(scale, int_scale, bit_width, inp)

    lo = int(min_int(signed, narrow_range, bit_width))
    hi = int(max_int(signed, bit_width))
    admissible_values = list(range(lo, hi + 1))

    # Reference path: map to the integer domain, clamp to the representable
    # range, then map back to the original domain.
    inp = (inp / scale) * int_scale
    expected_output = tensor_clamp(
        inp,
        min_val=min_int(signed, narrow_range, bit_width),
        max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale

    # NOTE(review): to_int receives the already-rescaled tensor here (same as
    # the original test) — confirm this is the intended input for to_int.
    int_output = quant.to_int(scale, int_scale, bit_width, inp)
    # The assert is performed internally by check_admissible_values
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, RTOL, atol)
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor:
    """Return the number of bits to remove, clamped to the legal range.

    The clamp bounds are derived from ``input_bit_width`` so that the
    overall bit width stays within [min_overall_bit_width,
    max_overall_bit_width]. ``tensor_clamp`` is used (rather than a hard
    clamp) so gradients pass through the boundaries, since
    ``input_bit_width`` is possibly learned.
    """
    removal = self.bit_width_to_remove_impl(zero_hw_sentinel)
    lower_bound = input_bit_width - self.max_overall_bit_width
    upper_bound = input_bit_width - self.min_overall_bit_width
    return tensor_clamp(removal, lower_bound, upper_bound)
def forward(self, input_bit_width: Tensor) -> Tensor:
    """Return the number of bits to remove, clamped to the legal range.

    Unlike the sentinel-based variant, the overall bit-width bounds are
    obtained by calling ``min_overall_bit_width()`` /
    ``max_overall_bit_width()``. ``tensor_clamp`` passes gradients to the
    boundaries, since ``input_bit_width`` is possibly learned.
    """
    removal = self.bit_width_to_remove_impl()
    lower_bound = input_bit_width - self.max_overall_bit_width()
    upper_bound = input_bit_width - self.min_overall_bit_width()
    return tensor_clamp(removal, lower_bound, upper_bound)
def forward(self, x: Tensor, zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Binary-quantize ``x`` to ``±scale``.

    Returns the quantized tensor, the scale, and the bit width (offset by
    the zero-valued sentinel so it stays part of the graph).
    """
    scale = self.scaling_impl(zero_hw_sentinel)
    clamped = tensor_clamp(x, -scale, scale)
    quantized = binary_sign_ste(clamped) * scale
    return quantized, scale, zero_hw_sentinel + self.bit_width
def forward(ctx, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
    """Autograd-function forward pass: clamp ``x`` to [min_val, max_val]."""
    return tensor_clamp(x, min_val, max_val)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Binary-quantize ``x`` to ``±scale``, applying the delay wrapper.

    Returns the (possibly delayed) quantized tensor, the scale computed
    from ``x``, and the bit width.
    """
    scale = self.scaling_impl(x)
    clamped = tensor_clamp(x, -scale, scale)
    quantized = binary_sign_ste(clamped) * scale
    quantized = self.delay_wrapper(x, quantized)
    return quantized, scale, self.bit_width()