def test_fixed_block_zero_exponent(self):
    """
    Invariant: when the max exponent of a block is zero, block floating
    point behaves similarly to fixed point with fl = wl - 1, except that
    the lowest fixed point value (-1) is never produced.
    """
    for wl, fl in [(3, 2), (5, 4)]:
        # Largest magnitude representable by both formats; keeping the inputs
        # inside [-t_max, t_max] avoids the asymmetric fixed point endpoint -1.
        t_max = 1 - (2 ** (-fl))
        to_quantize = torch.linspace(-t_max, t_max, steps=1000, device="cuda")
        fixed_quantized = fixed_point_quantize(
            to_quantize, wl=wl, fl=fl, rounding="nearest"
        )
        block_quantized = block_quantize(to_quantize, wl=wl, rounding="nearest")
        self.assertTrue(torch.eq(fixed_quantized, block_quantized).all().item())
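# Worked illustration of the invariant above (comments only; assumes wl = 3,
# fl = 2 and nearest rounding). Signed fixed point with wl = 3, fl = 2 hits the
# multiples of 2**-2 in [-1.0, 0.75]:
#     {-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75}
# Block floating point with wl = 3 and a shared max exponent of zero covers the
# same grid minus the asymmetric endpoint -1.0, so restricting the inputs to
# [-(1 - 2**-fl), 1 - 2**-fl] makes the two quantizers land on identical values,
# which is exactly what torch.eq(...).all() asserts.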
def test_block_floating_point(self):
    """
    Block floating point quantization on each available CUDA device should
    match the CPU implementation to within numerical noise.
    """
    for wl in [5, 3]:
        for rounding in ["nearest"]:
            for dim in [-1, 0, 1]:
                for device in ["cuda:%d" % d for d in range(torch.cuda.device_count())]:
                    t_max = 1 - (2 ** (-4))
                    to_quantize_cuda = torch.linspace(
                        -t_max, t_max, steps=1200, device=torch.device(device)
                    )
                    to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                    # Quantize the same data with the CUDA kernel and the CPU
                    # kernel; the two implementations should agree up to
                    # numerical noise.
                    block_quantized_cuda = block_quantize(
                        to_quantize_cuda, wl=wl, rounding=rounding
                    )
                    block_quantized_cpu = block_quantize(
                        to_quantize_cpu, wl=wl, rounding=rounding
                    )
                    mse = self.error(block_quantized_cuda, block_quantized_cpu)
                    self.assertTrue(
                        mse < 1e-15, msg="%.2e MSE on device '%s'" % (mse, device)
                    )
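# The test above compares CUDA and CPU results through `self.error`, whose
# definition is not shown in this excerpt. A minimal sketch consistent with how
# it is used (a non-negative distance that is ~0 when the two kernels agree);
# the exact implementation is an assumption, not the repository's actual helper:
def error(self, cuda_t, cpu_t):
    # Hypothetical helper: accumulated squared difference between the CUDA
    # result (moved to the CPU) and the CPU result.
    return ((cuda_t.cpu() - cpu_t) ** 2).sum().item()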