Example #1
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        x = torch.tensor([2.0, 3.0, 4.0, 5.0])
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
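The test above mirrors the usual post-training calibration flow: pass representative float tensors through the observer, then read the resulting scale and zero point from calculate_qparams(). A minimal standalone sketch of that flow (not taken from the examples on this page; the bin count, tensor values, and variable names are illustrative assumptions):

import torch
from torch.quantization import HistogramObserver

obs = HistogramObserver(bins=256, dtype=torch.quint8,
                        qscheme=torch.per_tensor_affine, reduce_range=False)
for batch in (torch.randn(16), torch.randn(16)):
    obs(batch)  # each call updates the running min/max and the histogram
scale, zero_point = obs.calculate_qparams()
# quantize a fresh tensor with the calibrated parameters
q = torch.quantize_per_tensor(torch.randn(16), float(scale), int(zero_point), torch.quint8)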
Example #2
    def test_histogram_observer_one_sided(self):
        myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
        x = torch.tensor([0.0, 0.3, 1.2, 1.7])
        y = torch.tensor([0.1, 1.3, 2.0, 2.7])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0)
        qparams = myobs.calculate_qparams()
        self.assertEqual(qparams[1].item(), 0)
Example #3
    def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):

        ref_obs = _ReferenceHistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)
        my_obs = HistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)

        for _ in range(10):
            X = torch.randn(N)
            my_obs(X)
            ref_obs(X)

        ref_qparams = ref_obs.calculate_qparams()
        my_qparams = my_obs.calculate_qparams()

        self.assertEqual(ref_qparams, my_qparams)
Example #4
    def test_histogram_observer_same_inputs(self):
        myobs = HistogramObserver(bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
        w = torch.ones(4, requires_grad=True)
        x = torch.zeros(4, requires_grad=True)
        y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        z = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(w)
        myobs(x)
        myobs(x)
        myobs(y)
        myobs(z)
        qparams = myobs.calculate_qparams()
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])
Example #5
def quantize_statically(model, inputs, data_loader, linear_only=False):
    if (hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder)
            and linear_only):
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        for layer_idx in range(len(model.encoder.encoder.transformer.layers)):
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.input_projection"
                .format(layer_idx)] = qconfig
            qconfig_dict[
                "encoder.encoder.transformer.layers.{}.attention.output_projection"
                .format(layer_idx)] = qconfig
            for mlp_idx, m in enumerate(model.encoder.encoder.transformer.
                                        layers[layer_idx].residual_mlp.mlp):
                if type(m) == torch.nn.Linear:
                    qconfig_dict[
                        "encoder.encoder.transformer.layers.{}.residual_mlp.mlp.{}"
                        .format(layer_idx, mlp_idx)] = qconfig
        trace = model.graph_mode_quantize(inputs,
                                          data_loader,
                                          qconfig_dict=qconfig_dict,
                                          force_quantize=True)
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)

    return trace
Example #6
    def test_histogram_observer_save_load_state_dict(self):
        """
        Smoke test on saving/loading state_dict
        """
        obs1 = HistogramObserver()
        obs1(torch.randn(4, 4, 4, 4))
        obs2 = HistogramObserver()
        obs2.load_state_dict(obs1.state_dict())
        self.assertEqual(obs2.min_val.shape, torch.Size([]))
        self.assertEqual(obs2.max_val.shape, torch.Size([]))
Example #7
    def test_histogram_observer_consistent_buffer_shape(self):
        """
        Ensures that the buffer shapes do not change from uninitialized to
        initialized states for HistogramObserver.
        """
        obs = HistogramObserver()
        min_shape_before = obs.min_val.shape
        max_shape_before = obs.max_val.shape
        for _ in range(2):
            obs(torch.randn(4, 4, 4, 4))
        self.assertEqual(min_shape_before, obs.min_val.shape)
        self.assertEqual(max_shape_before, obs.max_val.shape)
Example #8
def quantize_fx(model, inputs, data_loader, dynamic=True):

    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):

        static = not dynamic

        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )

        # Only linear layers
        qconfig_dict = {"": None}
        qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers,
            qconfig_dict)  # fuse modules and insert observers

        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

        # Trace the overall module
        trace = model.trace(inputs)

        return trace
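The helper above applies FX graph mode quantization to the transformer submodule: prepare_fx fuses modules and inserts observers, calibrate() runs sample batches to populate them, and convert_fx swaps in quantized modules. A toy sketch of the same prepare -> calibrate -> convert sequence on a stand-in module (assuming the older FX API used above, where prepare_fx takes only a module and a qconfig_dict; newer PyTorch releases also require example_inputs):

import torch
from torch.quantization import QConfig, HistogramObserver, default_weight_observer
from torch.quantization.quantize_fx import prepare_fx, convert_fx

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                  weight=default_weight_observer)
qconfig_dict = {"": None, "object_type": [(torch.nn.Linear, qconfig)]}
prepared = prepare_fx(toy, qconfig_dict)  # fuse modules and insert observers
with torch.no_grad():
    for _ in range(4):  # calibration passes populate the observers
        prepared(torch.randn(2, 8))
quantized = convert_fx(prepared)  # replace observed modules with quantized ones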
Example #9
class ObserverTest(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_minmax_observer(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
        result = myobs(x)
        result = myobs(y)
        self.assertEqual(result, y)
        self.assertEqual(myobs.min_val, 1.0)
        self.assertEqual(myobs.max_val, 8.0)
        qparams = myobs.calculate_qparams()
        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.062745
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0313725
                ref_zero_point = -128 if qdtype is torch.qint8 else 0
        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

    @given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
    def test_observer_scriptable(self, obs):
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)

        self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())
Example #10
    def test_observer_scriptable(self, qdtype, qscheme):
        ob_list = [
            HistogramObserver(dtype=qdtype, qscheme=qscheme),
            default_histogram_observer()
        ]
        for obs in ob_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertTrue(torch.equal(obs.histogram, loaded.histogram))
Example #11
def quantize_statically(model,
                        inputs,
                        data_loader,
                        linear_only=False,
                        module_swap=False):
    log_feature_usage("export.quantize.statically")
    if (hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder)
            and linear_only):
        log_accelerator_feature_usage("quantize.statically")
        qconfig = QConfig(
            activation=HistogramObserver.with_args(reduce_range=False),
            weight=default_weight_observer,
        )
        qconfig_dict = {"": None}
        if module_swap:
            layers = model.encoder.encoder.transformer.layers.layers
            layers_str = "encoder.encoder.transformer.layers.layers"
        else:
            layers = model.encoder.encoder.transformer.layers
            layers_str = "encoder.encoder.transformer.layers"

        # skip first layer
        for layer_idx in range(1, len(layers)):
            qconfig_dict[
                layers_str +
                ".{}.attention.input_projection".format(layer_idx)] = qconfig
            qconfig_dict[
                layers_str +
                ".{}.attention.output_projection".format(layer_idx)] = qconfig
            for mlp_idx, m in enumerate(layers[layer_idx].residual_mlp.mlp):
                # Only quantize the first linear layer, otherwise there are accuracy issues
                if type(m) == torch.nn.Linear and mlp_idx < 1:
                    qconfig_dict[layers_str + ".{}.residual_mlp.mlp.{}".format(
                        layer_idx, mlp_idx)] = qconfig
        trace = model.graph_mode_quantize(inputs,
                                          data_loader,
                                          qconfig_dict=qconfig_dict,
                                          force_quantize=True)
    else:
        trace = model.graph_mode_quantize(inputs, data_loader)

    return trace
Example #12
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        x = torch.tensor([2.0, 3.0, 4.0, 5.0])
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
Example #13
def quantize_fx(model, inputs, data_loader, dynamic=True, selective=False):

    if hasattr(model, "encoder") and isinstance(model.encoder, RoBERTaEncoder):

        static = not dynamic

        if dynamic:
            qconfig = per_channel_dynamic_qconfig
        else:
            qconfig = QConfig(
                activation=HistogramObserver.with_args(reduce_range=False),
                weight=default_weight_observer,
            )

        # Only linear layers
        qconfig_dict = {"": None}
        if static and selective:
            qconfig_dict["module_name"] = []
            layers = model.encoder.encoder.transformer.layers.layers.layers
            layers_str = "layers"
            # skip first layer
            for layer_idx in range(1, len(layers)):
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.input_projection".format(layer_idx),
                    qconfig,
                ))
                qconfig_dict["module_name"].append((
                    layers_str +
                    ".{}.attention.output_projection".format(layer_idx),
                    qconfig,
                ))
                for mlp_idx, m in enumerate(
                        layers[layer_idx].residual_mlp.mlp):
                    # Only quantize the first linear layer, otherwise there are accuracy issues with static quantization
                    if type(m) == torch.nn.Linear and mlp_idx < 1:
                        qconfig_dict["module_name"].append((
                            layers_str + ".{}.residual_mlp.mlp.{}".format(
                                layer_idx, mlp_idx),
                            qconfig,
                        ))
        else:
            qconfig_dict["object_type"] = [(torch.nn.Linear, qconfig)]

        def calibrate(model, loader, max_samples=-1):
            model.eval()
            with torch.no_grad():
                for (idx, d) in enumerate(loader):
                    print("Running sample input #" + str(idx))
                    model(d[1]["tokens"])
                    if idx == max_samples:
                        break

        prepared_model = prepare_fx(
            model.encoder.encoder.transformer.layers.layers,
            qconfig_dict)  # fuse modules and insert observers

        model.encoder.encoder.transformer.layers.layers = prepared_model
        if static:
            calibrate(model, data_loader)  # run calibration on sample data
        model.encoder.encoder.transformer.layers.layers = convert_fx(
            prepared_model)

        # Trace the submodule in order to fix the interface
        if static:
            input1 = torch.randn([2, 1, 1024], dtype=torch.float)
            input2 = torch.randn([1, 2]).bool()
            traced = torch.jit.trace(
                model.encoder.encoder.transformer.layers.layers,
                (input1, input2))
            model.encoder.encoder.transformer.layers.layers = traced

        # Trace the overall module
        trace = model.trace(inputs)

        return trace