    def test_sparsenn_compare_activations(self):
        sparse_nn = SparseNNModel().eval()

        # quantize the embeddings and the dense part separately, using FX graph mode
        sparse_nn.dense_top = prepare_fx(
            sparse_nn.dense_top,
            {'': torch.quantization.default_qconfig},
        )

        # calibrate
        idx = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.LongTensor([0, 4])
        x = torch.randn(2, 4)
        sparse_nn(idx, offsets, x)

        # convert
        sparse_nn_q = copy.deepcopy(sparse_nn)
        sparse_nn_q.dense_top = convert_fx(sparse_nn_q.dense_top)

        # test out compare activations API
        sparse_nn.dense_top, sparse_nn_q.dense_top = prepare_model_outputs(
            'fp32_prepared', sparse_nn.dense_top, 'int8', sparse_nn_q.dense_top,
            OutputLogger)

        # calibrate
        sparse_nn(idx, offsets, x)
        sparse_nn_q(idx, offsets, x)

        # inspect results
        act_compare_dict = get_matching_activations(
            'fp32_prepared', sparse_nn, 'int8', sparse_nn_q, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 4)
        self.assert_ns_compare_dict_valid(act_compare_dict)
    def test_match_activations_fun(self):
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = nn.Parameter(torch.Tensor(4, 4))
                self.b1 = nn.Parameter(torch.zeros(4))
                self.w2 = nn.Parameter(torch.Tensor(4, 4))
                self.b2 = nn.Parameter(torch.zeros(4))
                torch.nn.init.kaiming_uniform_(self.w1, a=math.sqrt(5))
                torch.nn.init.kaiming_uniform_(self.w2, a=math.sqrt(5))

            def forward(self, x):
                x = F.linear(x, self.w1, self.b1)
                x = F.linear(x, self.w2, self.b2)
                x = F.relu(x)
                return x

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(4, 4))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs(
            'fp32_prepared', mp, 'int8', mq, OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(4, 4)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations(
            'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)
    def test_match_activations_mod(self):
        m = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(1, 1, 1),
            nn.Conv2d(1, 1, 1),
        ).eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        mp(torch.randn(2, 1, 2, 2))
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)

        mp_ns, mq_ns = prepare_model_outputs(
            'fp32_prepared', mp, 'int8', mq, OutputLogger)

        expected_occurrence = {
            ns.call_module(OutputLogger): 2,
        }
        self.checkGraphModuleNodes(
            mp_ns, expected_node_occurrence=expected_occurrence)
        self.checkGraphModuleNodes(
            mq_ns, expected_node_occurrence=expected_occurrence)

        # TODO(before land): test both scripted and non-scripted
        mp_ns = torch.jit.script(mp_ns)
        mq_ns = torch.jit.script(mq_ns)

        # calibrate
        input_fp32 = torch.randn(2, 1, 2, 2)
        mp_ns(input_fp32)
        mq_ns(input_fp32)

        # check activation result correctness
        act_compare_dict = get_matching_activations(
            'fp32_prepared', mp_ns, 'int8', mq_ns, OutputLogger)
        self.assertTrue(len(act_compare_dict) == 2)
        self.assert_ns_logger_act_compare_dict_valid(act_compare_dict)