def test_activation_quantizer_lsT_train_and_eval():
    """Test train_and_eval mode of activation quantizer for least squares ternary."""
    torch.manual_seed(1234)
    x = torch.ones(32, 16, 3, 3) * 2
    x2 = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor
    x3 = torch.ones(32, 16, 3, 3) * 4
    quantizer_lsT_all_ma = ActivationQuantizerLST('train_and_eval', 0.9)
    quantizer_lsT_all_ma.train()

    # Moving average should cause tracked v1 to become 1.0 after this call.
    x_q_train_all_ma = quantizer_lsT_all_ma(x)
    _, expected = quantization.quantizer_ls_ternary(x, torch.tensor([1.0] * 32))
    assert torch.all(x_q_train_all_ma.eq(expected))

    # Moving average should cause tracked v1 to become 1 * 0.9 + 2 * 0.1 = 1.1 after this call.
    x_q_train_all_ma = quantizer_lsT_all_ma(x3)
    _, expected = quantization.quantizer_ls_ternary(x, torch.tensor([1.1] * 32))
    assert torch.all(x_q_train_all_ma.eq(expected))

    quantizer_lsT_all_ma.eval()
    x_q_eval_train_and_eval = quantizer_lsT_all_ma(x2)
    _, expected = quantization.quantizer_ls_ternary(x2, torch.tensor([1.1] * 32))
    assert torch.all(x_q_eval_train_and_eval.eq(expected))

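# The expected values above follow from an exponential moving average of the
# per-batch scaling factor: the first training batch initializes the tracked v1
# to the batch value (1.0 for a constant tensor of 2s), and later batches blend
# it as 0.9 * old + 0.1 * new, giving 1.1 when the batch v1 is 2.0. The helper
# below is only an illustrative sketch of that update rule; its name and the
# `initialized` flag are hypothetical and not part of the library's API.
def _ema_update_sketch(tracked_v1: torch.Tensor, batch_v1: torch.Tensor,
                       momentum: float = 0.9, initialized: bool = True) -> torch.Tensor:
    if not initialized:
        # First batch: adopt the batch scaling factor directly (tracked v1 -> 1.0).
        return batch_v1.clone()
    # Later batches: blend old and new, e.g. 1.0 * 0.9 + 2.0 * 0.1 = 1.1.
    return momentum * tracked_v1 + (1.0 - momentum) * batch_v1
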
def forward(self, w: torch.Tensor, skip: int = 3) -> torch.Tensor:  # type: ignore
    """Quantize the weight tensor with the least squares ternary quantizer."""
    if self.training:
        # Compute the optimal scaling factor for this weight and cache it for eval.
        v1, w_q = quantization.quantizer_ls_ternary(w, skip=skip)
        self.v1.copy_(v1)  # type: ignore
    else:
        # Reuse the scaling factor cached during training.
        _, w_q = quantization.quantizer_ls_ternary(w, self.v1, skip=skip)  # type: ignore
    return w_q

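# For reference, the eval branch above reconstructs the ternary weights from a
# cached per-row scaling factor. The sketch below shows what that reconstruction
# looks like for a *given* v1, mirroring the construction used in
# test_quantizer_ls_T_optimal in this file; it is an illustrative assumption
# about the shape handling, not the library's actual quantizer_ls_ternary
# implementation, and it ignores the `skip` argument handled by the real quantizer.
def _ternary_with_fixed_v1_sketch(w: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    # Broadcast one scaling factor per output row (first dimension) of w.
    v1 = v1.view(-1, *([1] * (w.dim() - 1)))
    ones = torch.ones_like(w)
    b1 = torch.where(w >= 0, ones, -ones)            # first binary component
    b2 = torch.where(w - v1 * b1 >= 0, ones, -ones)  # binarized residual
    return v1 * b1 + v1 * b2                         # values in {-2 * v1, 0, 2 * v1}
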
def test_quantizer_ls_T_all_inputs_equal():
    """Test ternary optimal least squares scaled binary quantization edge case."""
    torch.manual_seed(1234)
    x = torch.ones(32, 3, 16, 16) * 2
    _, x_q = quantization.quantizer_ls_ternary(x)
    assert torch.all(x_q == 2.0)

    # Test the case where only certain rows have all elements equal.
    x = torch.rand(32, 3, 16, 16)
    x[1, :, :, :] = torch.ones(3, 16, 16) * 2
    x[9, :, :, :] = torch.ones(3, 16, 16) * -3
    _, x_q = quantization.quantizer_ls_ternary(x)
    assert torch.all(x_q[1, :, :, :] == 2)
    assert torch.all(x_q[9, :, :, :] == -3)

def test_quantizer_ls2_better_than_lsT():
    """Test ls-2 is better than ls-T, which is better than ls-1."""
    torch.manual_seed(1234)
    x = torch.randn(1000, 3, 64, 64)
    _, _, x_q_ls2 = quantization.quantizer_ls_2(x, skip=1)
    _, x_q_lsT = quantization.quantizer_ls_ternary(x, skip=1)
    _, x_q_ls1 = quantization.quantizer_ls_1(x)
    ls2_costs = torch.norm((x_q_ls2 - x).view(1000, -1), dim=1)
    lsT_costs = torch.norm((x_q_lsT - x).view(1000, -1), dim=1)
    ls1_costs = torch.norm((x_q_ls1 - x).view(1000, -1), dim=1)
    assert torch.all(ls2_costs <= lsT_costs)
    assert torch.all(lsT_costs <= ls1_costs)

def test_weight_quantizer_lsT_modes():
    """Test training mode and eval mode for WeightQuantizerLST."""
    torch.manual_seed(1234)
    quantizer_lsT = weight_quantization.WeightQuantizerLST(32)
    w = torch.rand(32, 16, 3, 3)
    quantizer_lsT.train()
    _ = quantizer_lsT(w)
    v1 = quantizer_lsT.v1

    quantizer_lsT.eval()
    w = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor
    w_q_eval = quantizer_lsT(w)
    _, w_q_eval_expected = quantization.quantizer_ls_ternary(w, v1=v1)
    assert torch.all(w_q_eval.eq(w_q_eval_expected))

def test_activation_quantizer_lsT_no_ma():
    """Test no moving average mode of activation quantizer for least squares ternary."""
    torch.manual_seed(1234)
    x = torch.ones(32, 16, 3, 3) * 2
    x2 = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor
    quantizer_lsT_no_ma = ActivationQuantizerLST('off')
    quantizer_lsT_no_ma.train()
    quantizer_lsT_no_ma(x)
    x_q_train_no_ma = quantizer_lsT_no_ma(x)  # call twice so moving avg changes if used
    assert torch.all(x_q_train_no_ma == 2.0)

    quantizer_lsT_no_ma.eval()
    x_q_eval_no_ma = quantizer_lsT_no_ma(x2)
    # v1 should not be cached, so it should be recomputed
    _, expected = quantization.quantizer_ls_ternary(x2)
    assert torch.all(x_q_eval_no_ma.eq(expected))

def test_quantizer_ls_T_optimal():
    """Test ternary optimal least squares scaled binary quantization."""
    torch.manual_seed(1234)
    x = torch.randn(1000, 3, 64, 64)
    _, x_q = quantization.quantizer_ls_ternary(x, skip=1)
    assert x_q.shape == x.shape

    # Check x_q has lower least-squares error compared with using random scaling factors.
    rand_indices = torch.randint(0, 3 * 64 * 64, (1000,))
    subopt_v1 = x.view(1000, -1)[torch.arange(1000),
                                 rand_indices].view(1000, 1, 1, 1).abs()
    b1 = binarize(x)
    subopt_quantization = subopt_v1 * b1 + subopt_v1 * binarize(x - subopt_v1 * b1)
    opt_costs = torch.norm((x_q - x).view(1000, -1), dim=1)
    subopt_costs = torch.norm((subopt_quantization - x).view(1000, -1), dim=1)
    assert torch.all(opt_costs <= subopt_costs)

def _moving_average_quantization(self, x: torch.Tensor,
                                 vs: List[torch.Tensor]) -> torch.Tensor:
    """Return x quantized with the tracked (moving average) scaling factor vs[0]."""
    v1 = vs[0]
    _, x_q = quantization.quantizer_ls_ternary(x, v1)
    return x_q

def _batch_quantization(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return a 2-tuple of (scaling factors, quantized x)."""
    batch_v1, x_q = quantization.quantizer_ls_ternary(x)
    return batch_v1.view(1, -1), x_q