示例#1
0
    def test_sparsenn_compare_activations(self):
        """Compare fp32-prepared and int8 activations for ``SparseNNModel``.

        Only the ``dense_top`` submodule is prepared and converted with FX
        graph mode here; no other attribute of the model is reassigned by
        this test.
        """
        sparse_nn = SparseNNModel().eval()

        # prepare the dense part for quantization, using FX graph mode
        # (only `dense_top` is touched here)
        sparse_nn.dense_top = prepare_fx(
            sparse_nn.dense_top,
            {'': torch.quantization.default_qconfig},
        )

        # calibrate
        idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        x = torch.randn(2, 4)
        sparse_nn(idx, offsets, x)

        # convert a deep copy, so the prepared model is still available as
        # the fp32 side of the comparison
        sparse_nn_q = copy.deepcopy(sparse_nn)
        sparse_nn_q.dense_top = convert_fx(sparse_nn_q.dense_top)

        # test out compare activations API
        sparse_nn.dense_top, sparse_nn_q.dense_top = prepare_model_outputs(
            'fp32_prepared', sparse_nn.dense_top, 'int8', sparse_nn_q.dense_top, OutputLogger)

        # run both models on the same inputs; this is not calibration --
        # presumably the OutputLogger instances record activations here
        # (TODO confirm against the logger implementation)
        sparse_nn(idx, offsets, x)
        sparse_nn_q(idx, offsets, x)

        # inspect results
        act_compare_dict = get_matching_activations(
            'fp32_prepared', sparse_nn, 'int8', sparse_nn_q, OutputLogger)
        # assertEqual reports the actual length on failure, unlike
        # assertTrue(len(...) == 4)
        self.assertEqual(len(act_compare_dict), 4)
        self.assert_ns_compare_dict_valid(act_compare_dict)
示例#2
0
    def test_match_activations_fun(self):
        """Match activations between prepared and converted functional models.

        The model under test calls ``F.linear`` twice followed by ``F.relu``.
        The test instruments both the prepared and the converted graph with
        ``OutputLogger``, scripts them, runs one input through each, and
        checks the resulting comparison dict.
        """
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                # torch.Tensor(4, 4) is uninitialized; the weights are
                # filled in by kaiming_uniform_ below
                self.w1 = nn.Parameter(torch.Tensor(4, 4))
                self.b1 = nn.Parameter(torch.zeros(4))
                self.w2 = nn.Parameter(torch.Tensor(4, 4))
                self.b2 = nn.Parameter(torch.zeros(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                x = F.relu(x)
                return x

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        # calibrate the prepared model
        mp(torch.randn(4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq,
                                             OutputLogger)

        # both instrumented graphs are expected to contain two OutputLogger
        # call_module nodes
        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # run both instrumented models on the same input; this is not
        # calibration -- presumably the loggers record activations here
        # (TODO confirm against the logger implementation)
        input_fp32 = torch.randn(4, 4)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations('fp32_prepared', mp_ns,
                                                    'int8', mq_ns,
                                                    OutputLogger)
        # assertEqual reports the actual length on failure, unlike
        # assertTrue(len(...) == 2)
        self.assertEqual(len(act_compare_dict), 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
示例#3
0
    def test_match_activations_mod(self):
        """Match activations between prepared and converted module-based models.

        The model under test is a ``QuantStub`` followed by two 1x1
        ``nn.Conv2d`` layers. The test instruments both the prepared and the
        converted graph with ``OutputLogger``, scripts them, runs one input
        through each, and checks the resulting comparison dict.
        """
        m = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        # calibrate the prepared model
        mp(torch.randn(2, 1, 2, 2))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs('fp32_prepared', mp, 'int8', mq,
                                             OutputLogger)

        # both instrumented graphs are expected to contain two OutputLogger
        # call_module nodes
        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # run both instrumented models on the same input; this is not
        # calibration -- presumably the loggers record activations here
        # (TODO confirm against the logger implementation)
        input_fp32 = torch.randn(2, 1, 2, 2)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations('fp32_prepared', mp_ns,
                                                    'int8', mq_ns,
                                                    OutputLogger)
        # assertEqual reports the actual length on failure, unlike
        # assertTrue(len(...) == 2)
        self.assertEqual(len(act_compare_dict), 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)