def test_round_to_zero(x):
    """
    Check that round_to_zero agrees with NumPy's np.fix on a float tensor.

    np.fix rounds every element towards zero and is used as the reference.
    ``x`` is presumably a random float tensor supplied by the test's
    strategy/fixture (not visible here).
    """
    rounded = round_to_zero(x)
    numpy_reference = np.fix(x.numpy())
    assert_allclose(rounded, torch.from_numpy(numpy_reference))
def test_bwd(self, x: Tuple[Tensor, Tensor], ste_impl: Callable):
    """
    Check that ste_impl passes the upstream gradient through unchanged.

    ``x`` unpacks into an input tensor and an upstream gradient. After that
    gradient is backpropagated through ``ste_impl``, the gradient stored on
    the input must be close to the upstream gradient (identity backward).
    """
    inp, upstream_grad = x
    inp.requires_grad_(True)
    ste_output = ste_impl(inp)
    ste_output.backward(upstream_grad, retain_graph=True)
    assert_allclose(upstream_grad, inp.grad)
def test_bwd(self, x):
    """
    Check that scalar_clamp_min_ste_impl passes the gradient through to val.

    ``x`` unpacks into (min_val, val, val_grad). Only ``val`` has
    requires_grad enabled here. After ``val_grad`` is backpropagated, the
    gradient on ``val`` must be close to ``val_grad``.

    NOTE(review): nothing is asserted about a gradient on ``min_val``. The
    "val only" part of this test relies on min_val not requiring grad.
    """
    clamp_floor, clamped_val, clamped_grad = x
    clamped_val.requires_grad_(True)
    result = scalar_clamp_min_ste_impl(clamped_val, clamp_floor)
    result.backward(clamped_grad, retain_graph=True)
    assert_allclose(clamped_grad, clamped_val.grad)
def test_bwd(self, bit_width_parameter: BitWidthParameter, bit_width_grad):
    """
    Test that gradients are propagated to bit_width_parameter.bit_width_offset.

    Backpropagates ``bit_width_grad`` through the tensor returned by
    ``bit_width_parameter()``. Then checks that the gradient on
    ``bit_width_parameter.bit_width_offset`` is close to ``bit_width_grad``.

    ``self.clean_up_bwd`` runs in a ``finally`` block, so cleanup also happens
    when the backward pass or the assertion fails. Previously a failure skipped
    the cleanup.

    NOTE(review): this matters if the same fixture instance is reused across
    generated examples (e.g. hypothesis re-running while shrinking a failure).
    Gradients accumulate into ``.grad``, so leftover state would corrupt later
    runs. Confirm what clean_up_bwd resets.
    """
    try:
        bit_width_tensor = bit_width_parameter()
        bit_width_tensor.backward(bit_width_grad)
        assert_allclose(bit_width_parameter.bit_width_offset.grad, bit_width_grad)
    finally:
        # Always restore the parameter's state, even on failure.
        self.clean_up_bwd(bit_width_parameter)
def test_bwd_nz(self, inp, grad):
    """
    Check that abs_binary_sign_grad_impl backpropagates like torch.abs.

    Intended for ``inp != 0``, where |x| is differentiable. The gradient of
    ``inp`` through the implementation under test must be close to the
    gradient of an independent detached copy through ``torch.abs``.
    """
    import torch

    inp.requires_grad_(True)
    impl_out = abs_binary_sign_grad_impl(inp)
    impl_out.backward(grad)

    # Independent leaf copy so torch.abs accumulates its own gradient.
    torch_inp = inp.detach().clone()
    torch_inp.requires_grad_(True)
    torch_out = torch.abs(torch_inp)
    torch_out.backward(grad)

    assert_allclose(inp.grad, torch_inp.grad)
def test_bwd_zero(self, grad):
    """
    Check that abs_binary_sign_grad_impl uses a subgradient of 1 at inp == 0.

    At a zero input, the gradient w.r.t. the input must be close to the
    upstream ``grad`` itself, i.e. a subgradient of 1 rather than 0. A
    torch.abs reference is also run on the same zero input, and its forward
    output is checked to be 0.
    """
    import torch

    zero_inp = tensor(0.0)
    zero_inp.requires_grad_(True)
    impl_out = abs_binary_sign_grad_impl(zero_inp)
    impl_out.backward(grad)

    torch_inp = zero_inp.detach().clone()
    torch_inp.requires_grad_(True)
    torch_out = torch.abs(torch_inp)
    torch_out.backward(grad)

    assert_allclose(zero_inp.grad, grad)
    # NOTE(review): the gradient computed on torch_inp is never asserted on.
    # Only the forward value torch.abs(0) == 0 is checked below.
    assert torch_out == 0.0
def test_output_scale(self, ternary_quant, scaling_impl_all, inp):
    """
    Check that the scale reported by ternary_quant matches scaling_impl_all.

    ternary_quant(inp) is unpacked as a 4-tuple whose second element is the
    scale. That scale must be close to scaling_impl_all evaluated on the
    same input.
    """
    _, scale, _zero_point, _bit_width = ternary_quant(inp)
    expected_scale = scaling_impl_all(inp)
    assert_allclose(scale, expected_scale)
def test_output_zero_point(self, ternary_quant, inp):
    """
    Check that ternary_quant reports a zero point of 0.

    The third element of the 4-tuple returned by ternary_quant(inp) is the
    zero point. It must be close to 0.
    """
    quant_result = ternary_quant(inp)
    _, _scale, zero_point, _bit_width = quant_result
    assert_allclose(zero_point, torch.tensor(0.0))
def test_output_bit_width(self, ternary_quant, inp):
    """
    Check that ternary_quant reports a bit width of 2.

    The fourth element of the 4-tuple returned by ternary_quant(inp) is the
    bit width. It must be close to 2.
    """
    quant_result = ternary_quant(inp)
    _, _scale, _zero_point, bit_width = quant_result
    assert_allclose(bit_width, torch.tensor(2.0))
def test_output_bit_width(self, binary_quant_all, inp):
    """
    Check that binary_quant_all reports a bit width of 1.

    The fourth element of the 4-tuple returned by binary_quant_all(inp) is the
    bit width. It must be close to 1.
    """
    quant_result = binary_quant_all(inp)
    _, _, _, bit_width = quant_result
    assert_allclose(bit_width, torch.tensor(1.0))