import torch
from qtorch.quant import fixed_point_quantize, block_quantize


def test_fixed_point(self):
    """Fixed-point quantization must agree between the CPU and every CUDA device."""
    for wl, fl in [(5, 4), (3, 2)]:
        for rounding in ["nearest"]:
            for device in ["cuda:%d" % d for d in range(torch.cuda.device_count())]:
                # largest magnitude representable with fl fractional bits
                t_max = 1 - (2 ** (-fl))
                to_quantize_cuda = torch.linspace(
                    -t_max, t_max, steps=1200, device=torch.device(device)
                )
                to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                fixed_quantized_cuda = fixed_point_quantize(
                    to_quantize_cuda, wl=wl, fl=fl, rounding=rounding
                )
                fixed_quantized_cpu = fixed_point_quantize(
                    to_quantize_cpu, wl=wl, fl=fl, rounding=rounding
                )
                mse = self.error(fixed_quantized_cuda, fixed_quantized_cpu)
                self.assertTrue(mse < 1e-15, msg="%.2e MSE on device '%s'" % (mse, device))
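# Both tests call self.error(); the helper is not shown in this section. A
# minimal sketch, assuming it is a plain MSE over the two tensors (the name
# and body here are hypothetical):
def error(self, ref, res):
    # mean squared error, computed on CPU so CUDA/CPU tensors can be compared
    diff = ref.cpu() - res.cpu()
    return torch.mean(diff * diff).item()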
def QG(x, bits_G, bits_R, lr, mode="nearest"):
    # WAGE-style gradient quantization: shift the gradient to [-1, 1], quantize
    # the learning-rate-scaled result on a (bits_G, bits_G - 1) fixed-point grid,
    # then undo the 2^(bits_G - 1) scaling. bits_R is unused in this variant.
    x = shift(x)
    norm = fixed_point_quantize(
        lr * x, wl=bits_G, fl=bits_G - 1, clamp=False, symmetric=True, rounding=mode
    )
    return norm / (2.0 ** (bits_G - 1))
def QG(x, max_entry=None, bits_G=8, fl_mode="ceil", mode="nearest"):
    # Variant that records the pre-shift maximum so the quantized gradient can
    # be rebased to its original scale afterwards. fl_mode is unused here.
    x, max_entry = shift(x, max_entry)
    norm = fixed_point_quantize(
        x, wl=bits_G, fl=bits_G - 1, clamp=True, symmetric=True, rounding=mode
    )
    output = rebase(norm, max_entry)
    return output
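# The QG variants rely on shift() and rebase(), which are not defined in this
# section. A minimal sketch, assuming WAGE-style power-of-two scaling (the
# bodies are hypothetical; the variants above that call shift(x) and bind a
# single value presumably use a version returning only the scaled tensor):
def shift(x, max_entry=None):
    # scale x by a power of two so its largest magnitude lands in [0.5, 1),
    # and report the pre-shift maximum for a later rebase()
    if max_entry is None:
        max_entry = x.abs().max()
    return x / 2.0 ** torch.ceil(torch.log2(max_entry)), max_entry


def rebase(x, max_entry):
    # undo shift(): restore the power-of-two scale derived from max_entry
    return x * 2.0 ** torch.ceil(torch.log2(max_entry))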
def QG(x, bits_G, bits_R, lr, mode="nearest"):
    # Variant that folds the 2^(bits_G - 1) rescaling into the learning rate
    # before quantization instead of dividing afterwards. bits_R is unused here.
    x = shift(x)
    lr = lr / (2.0 ** (bits_G - 1))
    norm = fixed_point_quantize(
        lr * x, wl=bits_G, fl=bits_G - 1, clamp=False, symmetric=True, rounding=mode
    )
    return norm
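# Example (hypothetical values), quantizing a gradient with the lr-folding QG
# above; assumes the single-return shift() noted in the sketch earlier:
torch.manual_seed(0)
grad = torch.randn(4, 4)
g_q = QG(grad, bits_G=8, bits_R=16, lr=0.1)
# every entry of g_q is a multiple of 2**-7 (fl = bits_G - 1 = 7)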
def test_fixed_point(self):
    for wl, fl in [(5, 4), (3, 2)]:
        for rounding in ["nearest"]:
            t_max = 1 - (2 ** (-fl))
            to_quantize_cuda = torch.linspace(-t_max, t_max, steps=20, device="cuda")
            to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
            fixed_quantized_cuda = fixed_point_quantize(
                to_quantize_cuda, wl=wl, fl=fl, rounding=rounding
            )
            fixed_quantized_cpu = fixed_point_quantize(
                to_quantize_cpu, wl=wl, fl=fl, rounding=rounding
            )
            mse = self.error(fixed_quantized_cuda, fixed_quantized_cpu)
            self.assertTrue(mse < 1e-15)
def QW(x, bits, scale=1.0, mode="nearest"):
    # Quantize weights on a symmetric (bits, bits - 1) fixed-point grid.
    y = fixed_point_quantize(
        x, wl=bits, fl=bits - 1, clamp=True, symmetric=True, rounding=mode
    )
    # per-layer scaling: undo large initialization scales after quantization
    if scale > 1.8:
        y /= scale
    return y
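# Example (hypothetical values) for QW above. WAGE-style setups derive `scale`
# from the layer's initialization limit; a literal value stands in for that:
torch.manual_seed(0)
w = torch.empty(3, 3).uniform_(-2.0, 2.0)
w_q = QW(w, bits=2, scale=2.0)
# with bits=2 (fl=1) the symmetric grid is {-0.5, 0.0, 0.5}; since scale > 1.8,
# QW divides by it, so w_q's values lie in {-0.25, 0.0, 0.25}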
def test_fixed_block_zero_exponent(self):
    """
    Invariant: when the maximum exponent of a block is zero, block floating
    point behaves like a fixed point with fl = wl - 1, minus the lowest
    fixed-point value (-1), which the tested range excludes anyway.
    """
    for wl, fl in [(3, 2), (5, 4)]:
        t_max = 1 - (2 ** (-fl))
        to_quantize = torch.linspace(-t_max, t_max, steps=1000, device="cuda")
        fixed_quantized = fixed_point_quantize(
            to_quantize, wl=wl, fl=fl, rounding="nearest"
        )
        block_quantized = block_quantize(to_quantize, wl=wl, rounding="nearest")
        self.assertTrue(torch.eq(fixed_quantized, block_quantized).all().item())
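# Worked instance of the invariant above (wl=3, fl=2): the maximum magnitude
# 0.7 lies in [0.5, 1), so the block's shared exponent is zero and both
# quantizers snap to the same grid {-0.75, -0.5, ..., 0.75}. Values chosen to
# avoid rounding ties (illustrative, not part of the test suite):
x = torch.tensor([-0.7, -0.3, 0.1, 0.6])
fp = fixed_point_quantize(x, wl=3, fl=2, rounding="nearest")
bq = block_quantize(x, wl=3, rounding="nearest")
assert torch.equal(fp, bq)  # both give [-0.75, -0.25, 0.0, 0.5]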