Example #1
0
 def test_fixed_point(self):
     """Fixed-point quantization must agree between CPU and every CUDA device."""
     gpu_devices = ["cuda:%d" % idx
                    for idx in range(torch.cuda.device_count())]
     for wl, fl in [(5, 4), (3, 2)]:
         for rounding in ["nearest"]:
             for device in gpu_devices:
                 # Largest magnitude representable with fl fractional bits.
                 limit = 1 - (2 ** (-fl))
                 gpu_input = torch.linspace(
                     -limit, limit, steps=1200, device=torch.device(device))
                 cpu_input = gpu_input.clone().to("cpu")
                 gpu_result = fixed_point_quantize(
                     gpu_input, wl=wl, fl=fl, rounding=rounding)
                 cpu_result = fixed_point_quantize(
                     cpu_input, wl=wl, fl=fl, rounding=rounding)
                 mse = self.error(gpu_result, cpu_result)
                 self.assertTrue(mse < 1e-15,
                                 msg="%.2e MSE on device '%s'" %
                                 (mse, device))
Example #2
0
def QG(x, bits_G, bits_R, lr, mode="nearest"):
    """Quantize the (shifted) gradient to symmetric bits_G-bit fixed point.

    Args:
        x: gradient tensor to quantize.
        bits_G: word length of the gradient representation.
        bits_R: unused in this variant; kept for caller compatibility.
        lr: learning rate folded into the value before quantization.
        mode: rounding mode forwarded to fixed_point_quantize.

    Returns:
        The quantized gradient, rescaled by 1 / 2**(bits_G - 1).
    """
    x = shift(x)
    # FIX: fixed_point_quantize takes explicit wl/fl keywords — the sibling
    # QG/QW definitions in this file all call it that way. The original
    # passed a FixedPoint number object positionally, which would bind it
    # to `wl` and leave `fl` unsupplied. NOTE(review): assumes the modern
    # qtorch functional API (wl/fl signature) — confirm against the
    # installed qtorch version.
    norm = fixed_point_quantize(lr * x,
                                wl=bits_G,
                                fl=bits_G - 1,
                                clamp=False,
                                symmetric=True,
                                rounding=mode)
    return norm / (2.**((bits_G - 1)))
Example #3
0
def QG(x, max_entry=None, bits_G=8, fl_mode="ceil", mode="nearest"):
    """Quantize gradient x to symmetric bits_G-bit fixed point, then rebase.

    The input is first shifted (updating max_entry), quantized with
    fl = bits_G - 1, and finally rebased by the returned max_entry.
    """
    x, max_entry = shift(x, max_entry)
    quantized = fixed_point_quantize(x,
                                     wl=bits_G,
                                     fl=bits_G - 1,
                                     clamp=True,
                                     symmetric=True,
                                     rounding=mode)
    return rebase(quantized, max_entry)
Example #4
0
def QG(x, bits_G, bits_R, lr, mode="nearest"):
    """Quantize the learning-rate-scaled, shifted gradient to fixed point."""
    shifted = shift(x)
    # Fold the 2**(bits_G - 1) gradient scaling into the learning rate.
    scaled_lr = lr / (2.0 ** (bits_G - 1))
    return fixed_point_quantize(scaled_lr * shifted,
                                wl=bits_G,
                                fl=bits_G - 1,
                                clamp=False,
                                symmetric=True,
                                rounding=mode)
Example #5
0
 def test_fixed_point(self):
     """CUDA and CPU fixed-point quantization should agree to within 1e-15 MSE."""
     for wl, fl in [(5, 4), (3, 2)]:
         for rounding in ["nearest"]:
             # Largest magnitude representable with fl fractional bits.
             limit = 1 - (2 ** (-fl))
             gpu_input = torch.linspace(-limit,
                                        limit,
                                        steps=20,
                                        device='cuda')
             cpu_input = gpu_input.clone().to("cpu")
             gpu_result = fixed_point_quantize(gpu_input,
                                               wl=wl,
                                               fl=fl,
                                               rounding=rounding)
             cpu_result = fixed_point_quantize(cpu_input,
                                               wl=wl,
                                               fl=fl,
                                               rounding=rounding)
             mse = self.error(gpu_result, cpu_result)
             self.assertTrue(mse < 1e-15)
Example #6
0
def QW(x, bits, scale=1.0, mode="nearest"):
    """Quantize weights to symmetric bits-bit fixed point.

    Applies per-layer down-scaling by `scale` only when scale > 1.8.
    """
    quantized = fixed_point_quantize(x,
                                     wl=bits,
                                     fl=bits - 1,
                                     clamp=True,
                                     symmetric=True,
                                     rounding=mode)
    # per layer scaling
    if scale > 1.8:
        quantized = quantized / scale
    return quantized
Example #7
0
 def test_fixed_block_zero_exponent(self):
     """
     Invariant: when the max exponent of a block is zero, block floating
     point behaves like fixed point with fl = wl - 1, except for the fixed
     point's lowest value (-1), which the sampled range never reaches.
     """
     for wl, fl in [(3, 2), (5, 4)]:
         # Largest magnitude representable with fl fractional bits.
         bound = 1 - (2 ** (-fl))
         samples = torch.linspace(-bound, bound, steps=1000, device='cuda')
         via_fixed = fixed_point_quantize(samples, wl=wl, fl=fl,
                                          rounding='nearest')
         via_block = block_quantize(samples, wl=wl, rounding='nearest')
         self.assertTrue(torch.eq(via_fixed, via_block).all().item())