예제 #1
0
    def test_max(self):
        """Check ``calib.calibrate_weights(method="max")`` against manual max calibration.

        Two ``QuantLeNet`` instances are constructed, each right after seeding
        torch with the same value. The reference network is calibrated by hand
        through every weight quantizer of its quantized conv/linear layers; the
        test network goes through ``calib.calibrate_weights``. The resulting
        ``amax`` tensors must match exactly, in value and in shape.
        """
        quant_layer_types = (quant_nn.QuantConv2d, quant_nn.QuantLinear)

        # Re-seed before each construction so both networks start from the
        # same RNG state.
        torch.manual_seed(12345)
        ref_lenet = QuantLeNet()
        torch.manual_seed(12345)
        test_lenet = QuantLeNet()

        # Reference path: drive each weight quantizer manually.
        for layer in ref_lenet.modules():
            if not isinstance(layer, quant_layer_types):
                continue
            quantizer = layer.weight_quantizer
            quantizer.enable_calib()
            quantizer.disable_quant()
            quantizer(layer.weight)
            quantizer.load_calib_amax()

        # Path under test.
        calib.calibrate_weights(test_lenet, method="max")

        layer_pairs = zip(ref_lenet.modules(), test_lenet.modules())
        for ref_layer, test_layer in layer_pairs:
            if not isinstance(ref_layer, quant_layer_types):
                continue
            ref_amax = ref_layer.weight_quantizer.amax
            test_amax = test_layer.weight_quantizer.amax
            test_utils.compare(ref_amax, test_amax, rtol=0, atol=0, ctol=0)
            assert ref_amax.shape == test_amax.shape
예제 #2
0
def test_data_augmentation_collapse():
    """Check ``DatasetImpl.__getitem__`` with explicit augmentation parameters and ``collapse_length=2``.

    The augmentation parameters passed in must be echoed back in
    ``aug_params``, and the returned events, timestamps and images must match
    ground truth assembled from test elements 1 and 2.
    """
    dataset = DatasetImpl(path=data_path,
                          shape=[256, 256],
                          augmentation=True,
                          collapse_length=2,
                          is_raw=True)

    # Augmentation parameters fixed by the test.
    want_idx, want_k = 1, 2
    want_flip, want_angle = False, 0
    want_box = np.array([0, 0, 260, 346])
    want_seq_length = 1

    events, timestamps, images, aug_params = dataset.__getitem__(
        idx=want_idx,
        k=want_k,
        is_flip=want_flip,
        angle=want_angle,
        box=want_box,
        seq_length=want_seq_length)

    # aug_params must report the parameters in this positional order.
    assert want_idx == aug_params[0]
    assert want_seq_length == aug_params[1]
    assert want_k == aug_params[2]
    assert (want_box == aug_params[3]).all()
    assert want_angle == aug_params[4]
    assert want_flip == aug_params[5]

    # Ground truth from the two raw elements.
    # NOTE(review): indices 1/2 look like start/end timestamps and 3/4 like
    # the first/second image of an element -- confirm against read_test_elem.
    first = tuple(read_test_elem(1, element_index=0, box=want_box))
    second = tuple(read_test_elem(2, element_index=0, box=want_box))

    want_events = concat_events(first[0], second[0])
    want_timestamps = np.array([0, second[2] - first[1]])
    want_events['timestamp'] -= first[1]

    # The two elements must be contiguous: same boundary timestamp and image.
    assert first[2] == second[1]
    assert (first[4] == second[3]).all()

    stacked = np.concatenate([first[3][None], second[4][None]], axis=0)
    want_images = stacked.astype(np.float32)

    compare(events, want_events)
    assert (timestamps == want_timestamps).all()
    assert (images == want_images).all()
예제 #3
0
    def test_cuda_ext(self):
        """Check ``cuda_ext.fake_tensor_quant`` against ``tensor_quant.fake_tensor_quant``.

        Both implementations are run on the same random CUDA tensor, with the
        tensor's absolute maximum as amax, and must agree with zero tolerance.
        Several bit widths and both signedness settings are covered for
        float32; float16 is covered with the default arguments.
        """

        def abs_max(tensor):
            # A fresh amax tensor is computed for every call.
            return torch.max(torch.abs(tensor))

        x_torch = torch.Tensor(np.random.rand(1023).astype('float32')).cuda()

        for num_bits in [3, 4, 5, 7, 8, 11]:
            for unsigned in [True, False]:
                ext_out = cuda_ext.fake_tensor_quant(
                    x_torch, abs_max(x_torch), num_bits, unsigned)
                ref_out = tensor_quant.fake_tensor_quant(
                    x_torch, abs_max(x_torch), num_bits, unsigned)
                test_utils.compare(ext_out, ref_out, rtol=0, atol=0)

        # Test fp16
        x_torch_fp16 = torch.Tensor(
            np.random.rand(1023).astype('float16')).cuda().half()
        ext_out_fp16 = cuda_ext.fake_tensor_quant(x_torch_fp16,
                                                  abs_max(x_torch_fp16))
        ref_out_fp16 = tensor_quant.fake_tensor_quant(x_torch_fp16,
                                                      abs_max(x_torch_fp16))
        test_utils.compare(ext_out_fp16, ref_out_fp16, rtol=0, atol=0)
예제 #4
0
    def test_against_unquantized(self):
        """Compare a 16-bit fake-quantized ``QuantConv3d`` with a plain ``nn.Conv3d``.

        Both layers are constructed right after seeding torch (and CUDA, when
        available) with the same value, then fed the same CUDA input. Their
        outputs must agree within ``rtol=1e-6`` / ``atol=2e-4``.
        """
        kernel_size = 3
        test_input = torch.randn(16, _NUM_IN_CHANNELS, 24, 24, 24).cuda()

        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        # The input lives on the GPU, so the module must be moved there too;
        # otherwise the forward pass fails with a device mismatch.
        fake_quant_conv3d = quant_conv.QuantConv3d(
            _NUM_IN_CHANNELS,
            _NUM_OUT_CHANNELS,
            kernel_size,
            bias=True,
            quant_desc_input=QuantDescriptor(num_bits=16),
            quant_desc_weight=QuantDescriptor(num_bits=16, axis=(0))).cuda()

        # Reset seed. Make sure weight and bias are the same
        torch.manual_seed(1234)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(1234)
        conv3d = nn.Conv3d(_NUM_IN_CHANNELS,
                           _NUM_OUT_CHANNELS,
                           kernel_size,
                           bias=True).cuda()

        fake_quant_output = fake_quant_conv3d(test_input)
        output = conv3d(test_input)

        test_utils.compare(fake_quant_output, output, rtol=1e-6, atol=2e-4)
예제 #5
0
    def test_mse_with_axis(self):
        """Check per-channel "mse" weight calibration against a reference histogram calibrator.

        After ``calib.calibrate_weights(..., method="mse", perchannel=True)``,
        entry 1 of ``conv2``'s weight amax must exactly equal the amax that a
        standalone ``HistogramCalibrator`` computes from ``conv2.weight[1]``.
        """
        torch.manual_seed(12345)
        model = QuantLeNet()
        reference = calib.HistogramCalibrator(8, None, False)

        calib.calibrate_weights(model, method="mse", perchannel=True)

        # Reference amax computed from the second output channel only.
        reference.collect(model.conv2.weight[1])
        expected_amax = reference.compute_amax("mse")

        actual_amax = model.conv2.weight_quantizer.amax[1]
        test_utils.compare(expected_amax, actual_amax, rtol=0, atol=0, ctol=0)
예제 #6
0
    def test_percentile(self):
        """Check per-tensor "percentile" weight calibration against a reference histogram calibrator.

        After ``calib.calibrate_weights`` with ``method="percentile"`` and
        ``perchannel=False``, ``conv1``'s weight amax must exactly equal the
        amax a standalone ``HistogramCalibrator`` computes from ``conv1.weight``
        at the same percentile.
        """
        torch.manual_seed(12345)
        model = QuantLeNet()
        percentile = 99.99
        reference = calib.HistogramCalibrator(8, None, False)

        calib.calibrate_weights(model,
                                method="percentile",
                                perchannel=False,
                                percentile=percentile)

        reference.collect(model.conv1.weight)
        expected_amax = reference.compute_amax("percentile",
                                               percentile=percentile)

        actual_amax = model.conv1.weight_quantizer.amax
        test_utils.compare(expected_amax, actual_amax, rtol=0, atol=0, ctol=0)
예제 #7
0
    def test_track_amax(self):
        """Check ``MaxCalibrator`` with ``track_amax=True``.

        After collecting two random CUDA tensors, the computed amax must equal
        the larger of the two tensor maxima, and ``amaxs[i]`` must equal the
        maximum of the i-th collected tensor.
        """
        calibrator = calib.MaxCalibrator(8, None, False, track_amax=True)

        batches = [torch.rand(129).cuda(), torch.rand(127).cuda()]
        for batch in batches:
            calibrator.collect(batch)

        test_utils.compare(calibrator.compute_amax(),
                           torch.max(batches[0].max(), batches[1].max()),
                           atol=0,
                           rtol=0,
                           ctol=0)

        # Each tracked entry is compared with the max of its own batch.
        for position, batch in enumerate(batches):
            np.testing.assert_array_equal(calibrator.amaxs[position],
                                          batch.max().cpu().numpy())
예제 #8
0
    def test_simple_run(self):
        """Check that ``MaxCalibrator`` reports the maximum over all collected tensors.

        Also constructs a calibrator with the third positional argument set to
        ``True`` purely to confirm that construction succeeds.
        """
        calibrator = calib.MaxCalibrator(8, None, False)

        batches = [torch.rand(129).cuda(), torch.rand(127).cuda()]
        for batch in batches:
            calibrator.collect(batch)

        test_utils.compare(calibrator.compute_amax(),
                           torch.max(batches[0].max(), batches[1].max()),
                           atol=0,
                           rtol=0,
                           ctol=0)

        # Nothing to test other than creation
        calib.MaxCalibrator(8, None, True)
예제 #9
0
    def test_cuda_ext_with_axis(self):
        """Check ``cuda_ext.fake_tensor_quant_with_axis`` against the PyTorch implementation.

        A per-axis amax (one value per index of axis 1) is passed to the CUDA
        extension with the axis number, and to ``tensor_quant.fake_tensor_quant``
        reshaped to ``(1, -1, 1, 1)``. Outputs must agree with zero tolerance
        for several bit widths and both signedness settings.
        """
        x_torch = torch.Tensor(
            np.random.rand(3, 4, 5, 6).astype('float32')).cuda()

        # amax along axis 1
        amax_torch = torch.tensor([0.8, 0.9, 0.7, 0.6], device="cuda")

        settings = [(bits, is_unsigned)
                    for bits in (3, 4, 5, 7, 8, 11)
                    for is_unsigned in (True, False)]

        for num_bits, unsigned in settings:
            ext_out = cuda_ext.fake_tensor_quant_with_axis(
                x_torch, amax_torch, 1, num_bits, unsigned)
            ref_out = tensor_quant.fake_tensor_quant(
                x_torch, amax_torch.view(1, -1, 1, 1), num_bits, unsigned)
            test_utils.compare(ext_out, ref_out, rtol=0, atol=0)
예제 #10
0
    def test_fine_grain(self):
        """Check ``MaxCalibrator`` with ``axis=0``.

        After collecting two tensors, the amax must have 31 entries along its
        first dimension and equal ``reduce_amax`` of the element-wise maximum
        of both tensors over axes (1, 2, 3). After ``reset()``,
        ``compute_amax()`` must return ``None``.
        """
        kept_axis = 0
        reduced_axes = (1, 2, 3)
        calibrator = calib.MaxCalibrator(8, kept_axis, False)

        first = torch.rand(31, 63, 7, 7).cuda()
        second = torch.rand(31, 63, 7, 7).cuda()
        for batch in (first, second):
            calibrator.collect(batch)

        assert calibrator.compute_amax().shape[0] == 31

        test_utils.compare(
            calibrator.compute_amax(),
            quant_utils.reduce_amax(torch.max(first, second),
                                    axis=reduced_axes),
            atol=0,
            rtol=0,
            ctol=0)

        calibrator.reset()
        assert calibrator.compute_amax() is None
예제 #11
0
    def test_entropy_and_percentile_calib(self):
        """Check that ``load_calib_amax`` stores what the histogram calibrator computes.

        There is no independent reference: for both the "entropy" and the
        "percentile" method, the test only verifies that the quantizer's
        ``amax`` equals its own calibrator's ``compute_amax`` result.
        """
        descriptor = tensor_quant.QuantDescriptor(calib_method='histogram')
        quantizer = tensor_quantizer.TensorQuantizer(descriptor,
                                                     if_calib=True,
                                                     if_quant=False).cuda()

        first = torch.rand(3, 63, 7, 7).cuda()
        second = torch.rand(3, 63, 7, 7).cuda()

        def feed_calibration_data():
            # Run both tensors through the quantizer, which is in calib mode.
            quantizer(first)
            quantizer(second)

        feed_calibration_data()
        quantizer.load_calib_amax("entropy")
        test_utils.compare(quantizer._calibrator.compute_amax("entropy"),
                           quantizer.amax,
                           atol=0,
                           rtol=0,
                           ctol=0)
        quantizer._calibrator.reset()

        feed_calibration_data()
        quantizer.load_calib_amax("percentile", percentile=99.99)
        test_utils.compare(
            quantizer._calibrator.compute_amax("percentile", percentile=99.99),
            quantizer.amax,
            atol=0,
            rtol=0,
            ctol=0)
예제 #12
0
def test_dataloader():
    """Check the first batch of a ``DataLoader`` over ``DatasetImpl`` using ``collate_wrapper``.

    With ``batch_size=2`` and no shuffling, the batch must match ground truth
    assembled from test elements 0 and 1: concatenated events tagged with a
    per-sample index, per-sample timestamps, sample indices, images and size.
    """
    dataset = DatasetImpl(path=data_path,
                          shape=[260, 346],
                          augmentation=False,
                          collapse_length=1,
                          is_raw=True)
    loader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=collate_wrapper,
                                         batch_size=2,
                                         pin_memory=True,
                                         shuffle=False)
    batch = next(iter(loader))

    # NOTE(review): indices 1/2 look like start/end timestamps and 3/4 like
    # the two images of an element -- confirm against read_test_elem.
    samples = [
        tuple(read_test_elem(index, element_index=0, is_torch=True))
        for index in (0, 1)
    ]
    # Shift event timestamps of each sample by that sample's entry 1.
    for sample in samples:
        sample[0]['timestamp'] -= sample[1]

    want_events = concat_events(samples[0][0], samples[1][0])
    want_events['sample_index'] = np.hstack([
        np.full_like(sample[0]['x'], position)
        for position, sample in enumerate(samples)
    ])
    want_events = {
        key: torch.tensor(value)
        for key, value in want_events.items()
    }

    timestamp_values = []
    for sample in samples:
        timestamp_values.extend([0, sample[2] - sample[1]])
    want_timestamps = torch.tensor(timestamp_values, dtype=torch.float32)

    want_sample_idx = torch.tensor([0, 0, 1, 1], dtype=torch.long)

    # Image order: sample 0 entries 3 and 4, then sample 1 entries 3 and 4.
    image_tensors = [
        torch.tensor(sample[slot], dtype=torch.float32)[None, None]
        for sample in samples
        for slot in (3, 4)
    ]
    want_images = torch.cat(image_tensors, dim=0).to(torch.float32)

    compare(batch['events'], want_events)
    assert torch.equal(batch['timestamps'], want_timestamps)
    assert torch.equal(batch['sample_idx'], want_sample_idx)
    assert (batch['images'] == want_images).all()
    assert batch['size'] == 2
예제 #13
0
    def test_max_calib(self):
        """Check max calibration through ``TensorQuantizer``.

        Covers three things:

        * ``load_calib_amax`` raises ``RuntimeError`` when nothing has been
          collected yet.
        * With ``axis=0``, both the calibrator's amax and the loaded
          ``quantizer.amax`` equal the element-wise maximum of the two inputs'
          ``reduce_amax`` over axes (1, 2, 3).
        * With ``learn_amax=True``, ``init_learn_amax`` sets the clip range to
          plus/minus the largest calibrated value.
        """
        axis = 0
        reduce_axis = (1, 2, 3)
        # A single quantizer is enough here; an identical one was previously
        # built twice and the first copy discarded unused.
        quant_desc1 = tensor_quant.QuantDescriptor(axis=axis)
        quantizer1 = tensor_quantizer.TensorQuantizer(quant_desc1).cuda()
        quantizer1.enable_calib()

        # No data collected yet, so loading amax must fail.
        with pytest.raises(RuntimeError, match="Calibrator returned None"):
            quantizer1.load_calib_amax()

        x_1 = torch.rand(127, 63, 7, 7).cuda()
        x_2 = torch.rand(127, 63, 7, 7).cuda()
        quantizer1(x_1)
        quantizer1(x_2)
        quantizer1.disable_calib()

        global_amax = torch.max(
            quant_utils.reduce_amax(x_1, axis=reduce_axis, keepdims=True),
            quant_utils.reduce_amax(x_2, axis=reduce_axis, keepdims=True))
        test_utils.compare(quantizer1._calibrator.compute_amax(),
                           global_amax,
                           atol=0,
                           rtol=0,
                           ctol=0)

        quantizer1.load_calib_amax()
        test_utils.compare(quantizer1.amax,
                           global_amax,
                           atol=0,
                           rtol=0,
                           ctol=0)

        # Second quantizer: default axis, learnable amax.
        quant_desc2 = tensor_quant.QuantDescriptor(learn_amax=True)
        quantizer2 = tensor_quantizer.TensorQuantizer(quant_desc2).cuda()
        quantizer2.enable_calib()
        quantizer2(x_1)
        quantizer2(x_2)

        quantizer2.load_calib_amax()
        quantizer2.init_learn_amax()
        test_utils.compare(quantizer2.clip.clip_value_min,
                           -torch.max(global_amax),
                           atol=0,
                           rtol=0,
                           ctol=0)
        test_utils.compare(quantizer2.clip.clip_value_max,
                           torch.max(global_amax),
                           atol=0,
                           rtol=0,
                           ctol=0)
 def test_read_write(self):
     """Round-trip ``self.encoded_batch`` through an HDF5 file.

     The batch is written with ``write_encoded_batch`` and read back with
     ``read_encoded_quantized_batch`` over three ranges: (0, 3) must equal
     the whole batch, (0, 2) must equal ``self.encoded_batches[0]`` and
     (2, 3) must equal ``self.encoded_batches[1]``.
     """

     def read_range(path, begin, end):
         # Open the file, load the per-sample bookkeeping datasets and read
         # the requested range; the file is closed before returning.
         with h5py.File(path, 'r') as f:
             channels_per_sample = torch.tensor(f['channels_per_sample'])
             elements_per_sample = torch.tensor(f['elements_per_sample'])
             return read_encoded_quantized_batch(f, channels_per_sample,
                                                 elements_per_sample,
                                                 begin, end)

     # A private temporary directory avoids reusing the name of an already
     # deleted temporary file, and is cleaned up even if an assertion fails.
     with tempfile.TemporaryDirectory() as tmp_dir:
         filename = Path(tmp_dir) / 'encoded_batch.hdf5'
         write_encoded_batch(filename, self.encoded_batch)
         assert filename.is_file()

         compare(read_range(filename, 0, 3), self.encoded_batch)
         compare(read_range(filename, 0, 2), self.encoded_batches[0])
         compare(read_range(filename, 2, 3), self.encoded_batches[1])
 def test_encode(self):
     """Encoding ``self.decoded_batch`` must yield ``self.encoded_batch``."""
     compare(encode_quantized_batch(self.decoded_batch), self.encoded_batch)
 def test_join(self):
     """Joining ``self.encoded_batches`` must yield ``self.encoded_batch``."""
     compare(join_batches(self.encoded_batches), self.encoded_batch)
 def test_decode(self):
     """Decoding ``self.encoded_batch`` must yield ``self.decoded_batch``."""
     compare(decode_quantized_batch(self.encoded_batch), self.decoded_batch)