Example #1
from unittest.mock import Mock

import torch

# IntQuant, TensorClamp, min_int, max_int and tensor_clamp come from brevitas;
# RTOL and check_admissible_values are helpers defined in the test suite.
def test_IntQuant(x, narrow_range, signed, bit_width, scale, int_scale,
                  float_to_int_impl, scale_multiplier):
    float_to_int_impl_mock = Mock()
    tensor_clamp_impl = TensorClamp()

    value = torch.tensor(x)
    bit_width = torch.tensor(bit_width, dtype=torch.float)
    scale = torch.tensor(scale)
    int_scale = torch.tensor(int_scale)

    tol = scale * scale_multiplier
    float_to_int_impl_mock.side_effect = float_to_int_impl()

    obj = IntQuant(narrow_range=narrow_range,
                   signed=signed,
                   float_to_int_impl=float_to_int_impl_mock,
                   tensor_clamp_impl=tensor_clamp_impl)
    output = obj(scale, int_scale, bit_width, value)

    min_value = int(min_int(signed, narrow_range, bit_width))
    max_value = int(max_int(signed, bit_width))
    admissible_values = list(range(min_value, max_value + 1))

    scaled_value = (value / scale) * int_scale
    expected_output = tensor_clamp(scaled_value,
                                   min_val=min_int(signed, narrow_range,
                                                   bit_width),
                                   max_val=max_int(signed, bit_width))
    expected_output = (expected_output / int_scale) * scale

    # feed to_int the original value; the scaled copy above is only for the expected output
    int_output = obj.to_int(scale, int_scale, bit_width, value)

    # the assert is performed internally by check_admissible_values
    check_admissible_values(int_output, admissible_values)
    assert torch.allclose(expected_output, output, rtol=RTOL, atol=tol)
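
For context, here is a self-contained sketch of the quantize-clamp-dequantize round trip this test exercises. The min_int, max_int and int_quant below are illustrative stand-ins written for this note, mirroring the Brevitas helpers of the same names rather than reproducing them:

import torch

def min_int(signed: bool, narrow_range: bool, bit_width: torch.Tensor) -> torch.Tensor:
    # smallest representable integer, e.g. -128 (or -127 with narrow_range) for signed 8 bit
    if signed and narrow_range:
        return -(2 ** (bit_width - 1)) + 1
    if signed:
        return -(2 ** (bit_width - 1))
    return torch.zeros_like(bit_width)

def max_int(signed: bool, bit_width: torch.Tensor) -> torch.Tensor:
    # largest representable integer, e.g. 127 for signed 8 bit, 255 for unsigned
    return 2 ** (bit_width - 1) - 1 if signed else 2 ** bit_width - 1

def int_quant(x, scale, int_scale, bit_width, signed=True, narrow_range=False):
    y = (x / scale) * int_scale                     # map to the integer domain
    y = torch.round(y)                              # one possible float_to_int_impl
    y = torch.min(torch.max(y, min_int(signed, narrow_range, bit_width)),
                  max_int(signed, bit_width))       # clamp to the representable range
    return (y / int_scale) * scale                  # map back to the float domain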
Example #2
def forward(self, input_bit_width: Tensor, zero_hw_sentinel: Tensor) -> Tensor:
    bit_width_to_remove = self.bit_width_to_remove_impl(zero_hw_sentinel)
    min_bit_width_to_remove = input_bit_width - self.max_overall_bit_width
    max_bit_width_to_remove = input_bit_width - self.min_overall_bit_width
    bit_width_to_remove = tensor_clamp(bit_width_to_remove,      # pass gradient to boundaries
                                       min_bit_width_to_remove,  # since input_bit_width is possibly learned
                                       max_bit_width_to_remove)
    return bit_width_to_remove
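
The inline comment only holds if tensor_clamp is the gradient-aware variant. A minimal sketch of such a clamp, assuming a torch.where-based formulation (not necessarily Brevitas' exact implementation):

import torch

def tensor_clamp(x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
    # Where the clamp saturates, the output is the boundary tensor itself, so
    # autograd routes the gradient to min_val/max_val there instead of to x.
    # This is what lets a learned input_bit_width receive gradient through the
    # bounds computed above.
    out = torch.where(x > max_val, max_val, x)
    out = torch.where(out < min_val, min_val, out)
    return out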
Example #3
def forward(self, input_bit_width: Tensor) -> Tensor:
    bit_width_to_remove = self.bit_width_to_remove_impl()
    min_bit_width_to_remove = input_bit_width - self.max_overall_bit_width()
    max_bit_width_to_remove = input_bit_width - self.min_overall_bit_width()
    # pass gradient to boundaries since input_bit_width is possibly learned
    bit_width_to_remove = tensor_clamp(bit_width_to_remove,
                                       min_bit_width_to_remove,
                                       max_bit_width_to_remove)
    return bit_width_to_remove
Example #4
def forward(self, x: Tensor,
            zero_hw_sentinel: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    scale = self.scaling_impl(zero_hw_sentinel)
    y = tensor_clamp(x, -scale, scale)   # restrict x to [-scale, scale]
    y = binary_sign_ste(y) * scale       # binarize to {-scale, +scale}
    return y, scale, zero_hw_sentinel + self.bit_width
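
binary_sign_ste pairs a hard sign in the forward pass with a straight-through gradient in the backward pass. A self-contained sketch of that pattern (a stand-in written for this note, not Brevitas' actual definition):

import torch

class BinarySignSte(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # map to {-1, +1}, treating 0 as positive so no output collapses to 0
        return torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # straight-through estimator: act as identity for the gradient
        return grad_output

binary_sign_ste = BinarySignSte.apply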
Example #5
def forward(ctx, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
    # forward of a torch.autograd.Function (ctx in place of self): a plain clamp
    y = tensor_clamp(x, min_val, max_val)
    return y
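
The ctx parameter marks this as the forward of a torch.autograd.Function, whose backward presumably implements a straight-through estimator. A minimal sketch of the full pattern under that assumption:

import torch
from torch import Tensor

class TensorClampSte(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, min_val: Tensor, max_val: Tensor) -> Tensor:
        return torch.min(torch.max(x, min_val), max_val)

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        # straight-through: the gradient reaches x as if the clamp were the
        # identity; min_val and max_val receive no gradient
        return grad_output, None, None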
Example #6
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    scale = self.scaling_impl(x)          # scale may depend on the input statistics
    y = tensor_clamp(x, -scale, scale)    # restrict x to [-scale, scale]
    y = binary_sign_ste(y) * scale        # binarize to {-scale, +scale}
    y = self.delay_wrapper(x, y)          # may return x unchanged while quantization is delayed
    return y, scale, self.bit_width()
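
A quick numeric illustration of the clamp-then-binarize flow in Examples #4 and #6 (toy values, not from the source); the clamp matters for the gradients, while the forward output is simply ±scale:

import torch

x = torch.tensor([-2.0, -0.1, 0.3, 5.0])
scale = 1.5
y = torch.clamp(x, -scale, scale).sign() * scale
print(y)  # tensor([-1.5000, -1.5000,  1.5000,  1.5000])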