Exemplo n.º 1
0
 def test_block_floating_point(self):
     """Check that block_quantize gives matching results on CUDA and on CPU.

     For every word length / rounding combination, the same 20 evenly
     spaced values in [-(1 - 2**-4), 1 - 2**-4] are quantized once on the
     GPU and once on a CPU copy. The value returned by ``self.error`` for
     the two results (stored as ``mse``) must be below 1e-15.
     """
     # NOTE(review): ``dim`` is deliberately not iterated over. Looping over
     # dim in [-1, 0, 1] without forwarding it to block_quantize only
     # repeats the identical comparison three times. Forwarding ``dim``
     # presumably needs an input with more than one dimension -- confirm
     # against the block_quantize API before adding that coverage.
     t_max = 1 - (2**(-4))  # loop-invariant bound of the sampled range
     for wl in [5, 3]:
         for rounding in ["nearest"]:
             to_quantize_cuda = torch.linspace(-t_max,
                                               t_max,
                                               steps=20,
                                               device='cuda')
             to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
             block_quantized_cuda = block_quantize(to_quantize_cuda,
                                                   wl=wl,
                                                   rounding=rounding)
             block_quantized_cpu = block_quantize(to_quantize_cpu,
                                                  wl=wl,
                                                  rounding=rounding)
             mse = self.error(block_quantized_cuda, block_quantized_cpu)
             # assertLess reports the offending value on failure, which
             # assertTrue(mse < ...) does not; the message adds the
             # configuration that failed.
             self.assertLess(mse,
                             1e-15,
                             msg="CPU/CUDA mismatch for wl=%d, rounding='%s'" %
                             (wl, rounding))
Exemplo n.º 2
0
 def test_fixed_block_zero_exponent(self):
     """
     Compare block floating point against fixed point on a zero-exponent block.

     Stated invariant: when the max exponent of a block is zero, block
     floating point behaves like a fixed point format with fl = wl - 1,
     except that the lowest fixed-point number (-1) is not included. The
     sampled range therefore stops at +/-(1 - 2**-fl), and both quantizers
     are expected to return element-wise identical tensors.
     """
     word_and_frac_lengths = [(3, 2), (5, 4)]
     for wl, fl in word_and_frac_lengths:
         # Largest magnitude sampled; stays strictly inside (-1, 1).
         bound = 1 - (2 ** (-fl))
         samples = torch.linspace(-bound, bound, steps=1000, device='cuda')
         via_fixed = fixed_point_quantize(samples,
                                          wl=wl,
                                          fl=fl,
                                          rounding='nearest')
         via_block = block_quantize(samples, wl=wl, rounding='nearest')
         elementwise_match = torch.eq(via_fixed, via_block)
         self.assertTrue(elementwise_match.all().item())
Exemplo n.º 3
0
 def test_block_floating_point(self):
     """Check that block_quantize agrees between the CPU and each CUDA device.

     For every word length / rounding combination and every visible CUDA
     device, the same 1200 evenly spaced values in
     [-(1 - 2**-4), 1 - 2**-4] are quantized once on that device and once
     on a CPU copy. The value returned by ``self.error`` for the two
     results (stored as ``mse``) must be below 1e-15.

     The test is skipped when no CUDA device is visible.
     """
     devices = ["cuda:%d" % d for d in range(torch.cuda.device_count())]
     if not devices:
         # With an empty device list the loops below never execute, so the
         # test would pass without having checked anything.
         self.skipTest("no CUDA device available")

     # NOTE(review): ``dim`` is deliberately not iterated over. Looping over
     # dim in [-1, 0, 1] without forwarding it to block_quantize only
     # repeats the identical comparison three times. Forwarding ``dim``
     # presumably needs an input with more than one dimension -- confirm
     # against the block_quantize API before adding that coverage.
     t_max = 1 - (2**(-4))  # loop-invariant bound of the sampled range
     for wl in [5, 3]:
         for rounding in ["nearest"]:
             for device in devices:
                 to_quantize_cuda = torch.linspace(
                     -t_max,
                     t_max,
                     steps=1200,
                     device=torch.device(device))
                 to_quantize_cpu = to_quantize_cuda.clone().to("cpu")
                 block_quantized_cuda = block_quantize(
                     to_quantize_cuda, wl=wl, rounding=rounding)
                 block_quantized_cpu = block_quantize(to_quantize_cpu,
                                                      wl=wl,
                                                      rounding=rounding)
                 mse = self.error(block_quantized_cuda,
                                  block_quantized_cpu)
                 # assertLess states the comparison directly and includes
                 # both operands in the failure output.
                 self.assertLess(mse,
                                 1e-15,
                                 msg="%.2e MSE on device '%s'" %
                                 (mse, device))