def test_ema_quantization(self):
     """Compare an EMA-mode QuantizedLinear against a hand-built reference.

     For five random (3, 10) batches the reference keeps exponential moving
     averages of the input and output absolute maxima, seeded with the first
     batch's values. It fake-quantizes (8 bits) with scales derived from
     those averages and requires the layer output to match element-wise.
     Afterwards the layer's ``input_thresh`` / ``output_thresh`` must equal
     the reference averages.
     """
     decay = 0.9
     layer = QuantizedLinear(10, 5, bias=False, ema_decay=decay, mode="EMA")
     in_ema = None
     out_ema = None
     for _ in range(5):
         batch = torch.randn(3, 10)
         batch_in_max = batch.abs().max()
         # The first observation seeds the average; later ones are blended
         # in with weight (1 - decay).
         if in_ema is None:
             in_ema = batch_in_max
         else:
             in_ema = in_ema - (1 - decay) * (in_ema - batch_in_max)
         q_batch = fake_quantize_np(batch, get_scale(8, in_ema), 8)
         expected = (q_batch @ layer.fake_quantized_weight.t()).detach()
         batch_out_max = expected.abs().max()
         if out_ema is None:
             out_ema = batch_out_max
         else:
             out_ema = out_ema - (1 - decay) * (out_ema - batch_out_max)
         expected = fake_quantize_np(expected, get_scale(8, out_ema), 8)
         actual = layer(batch)
         self.assertTrue((expected == actual).all())
     self.assertEqual(layer.input_thresh, in_ema)
     self.assertEqual(layer.output_thresh, out_ema)
 def test_ema_quantization_data_parallel(self):
     """Check EMA threshold tracking of QuantizedLinear under nn.DataParallel.

     Skipped (instead of silently passing) unless at least two CUDA devices
     are available. Each of five steps feeds a 2-row batch through the
     wrapped layer while a reference updates EMAs of the input/output
     absolute maxima from row 0 only (seeded with the first step's values).
     At the end the wrapped module's ``input_thresh`` / ``output_thresh``
     must equal the reference values. The forward output itself is not
     compared.
     """
     if not torch.cuda.is_available() or torch.cuda.device_count() <= 1:
         # A bare ``return`` would make the runner report this test as
         # passed although nothing was exercised; record it as skipped.
         self.skipTest("requires at least two CUDA devices")
     ema_decay = 0.9
     # NOTE(review): if FakeLinearQuantizationWithSTE is a
     # torch.autograd.Function, ``.apply`` can be called on the class itself;
     # recent PyTorch deprecates instantiating autograd Functions — confirm.
     fake_quantize = FakeLinearQuantizationWithSTE().apply
     qlinear = nn.DataParallel(
         QuantizedLinear(10, 5, bias=False, ema_decay=ema_decay,
                         mode="EMA")).cuda()
     for i in range(5):
         x = torch.randn(2, 10).cuda()
         # Reference statistics use row 0 only — presumably because
         # DataParallel scatters the batch one row per GPU and only the
         # device-0 replica's in-place buffer updates persist on the wrapped
         # module (see the nn.DataParallel docs); confirm.
         tmp_input_thresh = x[0].abs().max()
         if i == 0:
             input_ema = tmp_input_thresh
         else:
             input_ema -= (1 - ema_decay) * (input_ema - tmp_input_thresh)
         y = (fake_quantize(x, get_scale(8, input_ema), 8)
              @ qlinear.module.fake_quantized_weight.t()).detach()
         tmp_output_thresh = y[0].abs().max()
         if i == 0:
             output_ema = tmp_output_thresh
         else:
             output_ema -= (1 - ema_decay) * \
                           (output_ema - tmp_output_thresh)
         # Run the forward only for its side effect of updating the EMA
         # thresholds; its output is not checked here.
         qlinear(x)
     self.assertEqual(qlinear.module.input_thresh, input_ema)
     self.assertEqual(qlinear.module.output_thresh, output_ema)