예제 #1
0
    def test_calibrator_caches_without_explicit_cache(self, identity_builder_network):
        """Verify the calibrator holds cache data when no ``cache`` argument is given.

        A ``Calibrator`` is constructed from the data alone, an engine is built
        with ``int8=True`` using that calibrator, and afterwards
        ``read_calibration_cache()`` must return a truthy value.
        """
        trt_builder, trt_network = identity_builder_network
        calib_batches = [{"x": np.ones((1, 1, 2, 2), dtype=np.float32)}]
        calibrator = Calibrator(calib_batches)

        # Build the engine so that the calibrator's cache gets populated;
        # the engine itself is not used for anything else.
        load_engine = EngineFromNetwork(
            (trt_builder, trt_network),
            CreateConfig(int8=True, calibrator=calibrator),
        )
        with load_engine():
            pass

        # The calibrator was created without a cache argument, yet it must
        # still hand back cache contents here.
        assert calibrator.read_calibration_cache()
예제 #2
0
    def test_calibrator_rechecks_cache_on_reset(self, identity_builder_network):
        """Verify that ``reset()`` makes the calibrator go back to the cache file.

        The calibrator is given the path of a temporary file as its cache and
        an engine is built with ``int8=True``. After ``reset()``,
        ``has_cached_scales`` must be falsy, and the data returned by
        ``read_calibration_cache()`` must be exactly as long as the cache
        file's size on disk.
        """
        trt_builder, trt_network = identity_builder_network
        calib_batches = [{"x": np.ones((1, 1, 2, 2), dtype=np.float32)}]

        with tempfile.NamedTemporaryFile(mode="wb+") as cache_file:
            cache_path = cache_file.name
            calibrator = Calibrator(calib_batches, cache=cache_path)

            # Build the engine once so that the cache gets populated.
            load_engine = EngineFromNetwork(
                (trt_builder, trt_network),
                CreateConfig(int8=True, calibrator=calibrator),
            )
            with load_engine():
                pass

            # After a reset the calibrator should no longer report cached
            # scales and should read the cache again on request.
            calibrator.reset()
            assert not calibrator.has_cached_scales

            # Compare what the calibrator returns against the file on disk.
            cached_bytes = calibrator.read_calibration_cache()
            on_disk_size = os.stat(cache_path).st_size
            assert len(cached_bytes) == on_disk_size