Example #1
def test_activation_quantizer_lsT_train_and_eval():
    """Test train_and_eval mode of activation quantizer for least squares ternary."""
    torch.manual_seed(1234)
    x = torch.ones(32, 16, 3, 3) * 2
    x2 = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor
    x3 = torch.ones(32, 16, 3, 3) * 4

    quantizer_lsT_all_ma = ActivationQuantizerLST('train_and_eval', 0.9)

    quantizer_lsT_all_ma.train()
    # moving average should cause tracked v1 to become 1.0 after call
    x_q_train_all_ma = quantizer_lsT_all_ma(x)
    _, expected = quantization.quantizer_ls_ternary(x,
                                                    torch.tensor([1.0] * 32))
    assert torch.all(x_q_train_all_ma.eq(expected))
    # moving average should cause tracked v1 to become 1 * 0.9 + 2 * 0.1 = 1.1 after call
    x_q_train_all_ma = quantizer_lsT_all_ma(x3)
    _, expected = quantization.quantizer_ls_ternary(x3,
                                                     torch.tensor([1.1] * 32))
    assert torch.all(x_q_train_all_ma.eq(expected))

    quantizer_lsT_all_ma.eval()
    x_q_eval_train_and_eval = quantizer_lsT_all_ma(x2)
    _, expected = quantization.quantizer_ls_ternary(x2,
                                                    torch.tensor([1.1] * 32))
    assert torch.all(x_q_eval_train_and_eval.eq(expected))
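
The arithmetic asserted above assumes a plain exponential moving average over the per-batch least-squares scale. A minimal sketch of that update (the ema_update helper below is hypothetical, not part of the library):

import torch

def ema_update(tracked_v1: torch.Tensor, batch_v1: torch.Tensor,
               momentum: float = 0.9) -> torch.Tensor:
    """Hypothetical helper: keep 90% of the tracked scale, take 10% from the batch."""
    return momentum * tracked_v1 + (1.0 - momentum) * batch_v1

v1 = torch.tensor([1.0])                   # tracked scale after the first batch of all 2s
v1 = ema_update(v1, torch.tensor([2.0]))   # a batch of all 4s has per-batch v1 = 2
assert torch.allclose(v1, torch.tensor([1.1]))  # matches 1 * 0.9 + 2 * 0.1 above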
Example #2
 def forward(self, w: torch.Tensor, skip: int = 3) -> torch.Tensor:  # type: ignore
     """Forward pass of quantizing weight using least squares ternary."""
     if self.training:
         v1, w_q = quantization.quantizer_ls_ternary(w, skip=skip)
         self.v1.copy_(v1)  # type: ignore
     else:
         _, w_q = quantization.quantizer_ls_ternary(w, self.v1, skip=skip)  # type: ignore
     return w_q
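
The copy_ into self.v1 above implies v1 is a persistent tensor on the module. One way to set that up, shown purely as an illustrative sketch and not the project's actual WeightQuantizerLST (example #5 constructs it with 32, the number of output channels of the weight), is to register a non-trainable buffer:

import torch
from torch import nn

class WeightQuantizerLSTSketch(nn.Module):
    """Illustrative sketch only; the real module lives in weight_quantization."""

    def __init__(self, out_channels: int) -> None:
        super().__init__()
        # A buffer is saved in the state dict but receives no gradients, which
        # fits a scale cached with copy_() in training and read back at eval.
        self.register_buffer('v1', torch.ones(out_channels))

A forward like the one in example #2 can then attach to such a module, with quantization.quantizer_ls_ternary assumed importable from the project as in the snippets above.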
Example #3
def test_quantizer_ls_T_all_inputs_equal():
    """Test ternary optimal least squares scaled binary quantization edge case."""
    torch.manual_seed(1234)
    x = torch.ones(32, 3, 16, 16) * 2
    _, x_q = quantization.quantizer_ls_ternary(x)

    assert torch.all(x_q == 2.0)

    # Test the case where only certain rows have all elements equal
    x = torch.rand(32, 3, 16, 16)
    x[1, :, :, :] = torch.ones(3, 16, 16) * 2
    x[9, :, :, :] = torch.ones(3, 16, 16) * -3

    _, x_q = quantization.quantizer_ls_ternary(x)

    assert torch.all(x_q[1, :, :, :] == 2)
    assert torch.all(x_q[9, :, :, :] == -3)
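
Why the edge case reconstructs exactly: with a shared scale, the ternary form v1*b1 + v1*b2 can only produce -2*v1, 0, or 2*v1, so a row holding a constant c is recovered without error by v1 = |c| / 2. A hedged numeric check of that reasoning (not library code):

import torch

c = torch.full((3, 16, 16), -3.0)   # a constant row, as in x[9] above
v1 = c.abs().mean() / 2             # |c| / 2 = 1.5
b1 = torch.sign(c)
b2 = torch.sign(c - v1 * b1)        # the residual keeps the sign of c
assert torch.all(v1 * b1 + v1 * b2 == c)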
Example #4
def test_quantizer_ls2_better_than_lsT():
    """Test ls-2 is better than ls-T, which is better than ls-1."""
    torch.manual_seed(1234)
    x = torch.randn(1000, 3, 64, 64)

    _, _, x_q_ls2 = quantization.quantizer_ls_2(x, skip=1)
    _, x_q_lsT = quantization.quantizer_ls_ternary(x, skip=1)
    _, x_q_ls1 = quantization.quantizer_ls_1(x)

    ls2_costs = torch.norm((x_q_ls2 - x).view(1000, -1), dim=1)
    lsT_costs = torch.norm((x_q_lsT - x).view(1000, -1), dim=1)
    ls1_costs = torch.norm((x_q_ls1 - x).view(1000, -1), dim=1)

    assert torch.all(ls2_costs <= lsT_costs)
    assert torch.all(lsT_costs <= ls1_costs)
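
The ordering asserted above follows from nesting: ls-1 uses a single scaled binary term, ls-T adds a second term that reuses the same scale, and ls-2 lets the two scales differ. Every reconstruction available to the simpler form is also available to the richer one (ls-T can put both terms on the same sign, ls-2 can set v2 = v1), so the optimal least-squares error can only stay equal or decrease. The functions below are an illustrative sketch of those three forms for given scales, not the library's internals:

import torch

def reconstruct_ls1(x: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    return v1 * torch.sign(x)                        # one scaled binary term

def reconstruct_lsT(x: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    b1 = torch.sign(x)
    return v1 * b1 + v1 * torch.sign(x - v1 * b1)    # two terms, shared scale

def reconstruct_ls2(x: torch.Tensor, v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    b1 = torch.sign(x)
    return v1 * b1 + v2 * torch.sign(x - v1 * b1)    # two terms, independent scales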
Example #5
def test_weight_quantizer_lsT_modes():
    """Test training mode and eval mode for WeightQuantizerLST."""
    torch.manual_seed(1234)
    quantizer_lsT = weight_quantization.WeightQuantizerLST(32)
    w = torch.rand(32, 16, 3, 3)

    quantizer_lsT.train()
    _ = quantizer_lsT(w)
    v1 = quantizer_lsT.v1

    quantizer_lsT.eval()
    w = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor
    w_q_eval = quantizer_lsT(w)
    _, w_q_eval_expected = quantization.quantizer_ls_ternary(w, v1=v1)

    assert torch.all(w_q_eval.eq(w_q_eval_expected))
Example #6
def test_activation_quantizer_lsT_no_ma():
    """Test no moving average mode of activation quantizer for least squares ternary."""
    torch.manual_seed(1234)
    x = torch.ones(32, 16, 3, 3) * 2
    x2 = torch.rand(32, 16, 3, 3)  # some random, but all positive tensor

    quantizer_lsT_no_ma = ActivationQuantizerLST('off')
    quantizer_lsT_no_ma.train()
    quantizer_lsT_no_ma(x)
    x_q_train_no_ma = quantizer_lsT_no_ma(
        x)  # call twice so moving avg changes if used
    assert torch.all(x_q_train_no_ma == 2.0)

    quantizer_lsT_no_ma.eval()
    x_q_eval_no_ma = quantizer_lsT_no_ma(x2)
    # v1 should not be cached, so it should be recomputed
    _, expected = quantization.quantizer_ls_ternary(x2)
    assert torch.all(x_q_eval_no_ma.eq(expected))
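
Taken together, the two activation-quantizer tests (examples #1 and #6) pin down a mode dispatch roughly like the sketch below. The attribute names (moving_average_mode, momentum, v1) are assumptions for illustration, not the real ActivationQuantizerLST internals, and quantization.quantizer_ls_ternary is the same project helper used throughout these snippets:

def activation_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Hedged sketch of an ActivationQuantizerLST-style forward, not the actual code."""
    if self.moving_average_mode == 'off':
        # 'off': never cache a scale; recompute it from the current batch in
        # both train() and eval() mode (what example #6 checks).
        _, x_q = quantization.quantizer_ls_ternary(x)
        return x_q
    if self.training:
        # 'train_and_eval': fold the per-batch scale into the tracked moving
        # average (see the EMA sketch after example #1) before quantizing.
        batch_v1, _ = quantization.quantizer_ls_ternary(x)
        self.v1 = self.momentum * self.v1 + (1 - self.momentum) * batch_v1
    _, x_q = quantization.quantizer_ls_ternary(x, self.v1)
    return x_q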
Example #7
def test_quantizer_ls_T_optimal():
    """Test ternary optimal least squares scaled binary quantization."""
    torch.manual_seed(1234)
    x = torch.randn(1000, 3, 64, 64)

    _, x_q = quantization.quantizer_ls_ternary(x, skip=1)
    assert x_q.shape == x.shape

    # Check x_q has lower least-squares error compared with using random scaling factors
    rand_indices = torch.randint(0, 3 * 64 * 64, (1000, ))
    subopt_v1 = x.view(1000, -1)[torch.arange(1000),
                                 rand_indices].view(1000, 1, 1, 1).abs()
    b1 = binarize(x)
    subopt_quantization = subopt_v1 * b1 + subopt_v1 * binarize(x -
                                                                subopt_v1 * b1)

    opt_costs = torch.norm((x_q - x).view(1000, -1), dim=1)
    subopt_costs = torch.norm((subopt_quantization - x).view(1000, -1), dim=1)
    assert torch.all(opt_costs <= subopt_costs)
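
The baseline above plugs a randomly picked magnitude into the same two-term form, so the assertion says the scale returned by quantizer_ls_ternary is at least as good in the least-squares sense. As a hedged reference for what that objective looks like, here is a naive per-row grid search, not the library's algorithm (which presumably solves for the scale directly):

import torch

def naive_best_shared_scale(x: torch.Tensor, num_candidates: int = 256) -> torch.Tensor:
    """Grid-search the shared scale v1 per row of a 2D tensor, minimizing
    ||x - (v1*b1 + v1*b2)||^2 with b1 = sign(x) and b2 = sign(x - v1*b1)."""
    best_v1 = torch.zeros(x.shape[0])
    best_err = torch.full((x.shape[0],), float('inf'))
    b1 = torch.sign(x)
    for frac in torch.linspace(1e-3, 1.0, num_candidates):
        v1 = frac * x.abs().amax(dim=1)              # candidate scale per row
        x_q = v1.unsqueeze(1) * (b1 + torch.sign(x - v1.unsqueeze(1) * b1))
        err = ((x_q - x) ** 2).sum(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_v1 = torch.where(better, v1, best_v1)
    return best_v1

Run on x.view(1000, -1), the per-row error of such a search should track opt_costs above up to the grid resolution.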
Example #8
 def _moving_average_quantization(self, x: torch.Tensor,
                                  vs: List[torch.Tensor]) -> torch.Tensor:
     """Return quantized x using vs."""
     v1 = vs[0]
     _, x_q = quantization.quantizer_ls_ternary(x, v1)
     return x_q
Example #9
 def _batch_quantization(
         self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """Return a 2-tuple of (scaling factors, quantized x)."""
     batch_v1, x_q = quantization.quantizer_ls_ternary(x)
     return batch_v1.view(1, -1), x_q