def test_matching_failure_node_type(self):
    """Matching graphs whose node types differ must raise GraphMatchingException.

    Model A holds a Conv2d and model B a Linear, so no node pairing exists.
    """
    # verify that matching graphs with non-matching node types fails
    m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
    m2 = nn.Sequential(nn.Linear(1, 1)).eval()
    mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
    mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
    # Neither the exception object nor a return value is inspected, so no
    # names are bound here (the call is expected never to return).
    with self.assertRaises(GraphMatchingException):
        get_matching_node_pairs(mp1, mp2)
def compare_weights(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Collect the weights of matched weighted nodes from two graph models.

    Nodes of ``gm_a`` and ``gm_b`` are paired with ``get_matching_node_pairs``.
    For every pair whose A-side node is a supported weighted op, the weight
    extracted from each side is recorded under the match name, keyed by
    ``name_a`` and ``name_b``. Pairs of unsupported ops are skipped.

    Currently supported:
      * ``call_function`` nodes whose target is ``F.linear`` or is listed as
        related to ``F.linear`` by ``get_type_a_related_to_b()``
      * ``call_module`` nodes whose module is an ``nn.Conv2d`` instance or
        whose type is listed as related to ``nn.Conv2d``

    Only the A-side node is type-checked; the B-side node is assumed to be
    related because the graph matcher paired it with A.

    Args:
        name_a: key under which weights taken from ``gm_a`` are stored
        gm_a: first graph model
        name_b: key under which weights taken from ``gm_b`` are stored
        gm_b: second graph model

    Returns:
        ``{match_name: {name_a: weight_a, name_b: weight_b}}`` with one entry
        per matched pair whose weights could be extracted.
    """
    related_types = get_type_a_related_to_b()
    pairs = get_matching_node_pairs(gm_a, gm_b)
    weights_by_match = {}

    for match_name, (node_a, node_b) in pairs.items():
        # matched nodes must be of the same op kind, and only function or
        # module calls are expected from the matcher
        assert node_a.op == node_b.op and \
            node_a.op in ('call_function', 'call_module')

        extracted = None
        if node_a.op == 'call_function':
            # linear
            # TODO(future PR): other function types
            is_linear = node_a.target in (F.linear,) or \
                (node_a.target, F.linear) in related_types
            if is_linear:
                extracted = (
                    get_linear_fun_weight(node_a, gm_a),
                    get_linear_fun_weight(node_b, gm_b),
                )
        else:
            # call_module: the type check needs the actual module objects,
            # so resolve each node's target path on its own model first
            assert isinstance(node_a.target, str)
            mod_a = getattr_from_fqn(gm_a, node_a.target)
            assert isinstance(node_b.target, str)
            mod_b = getattr_from_fqn(gm_b, node_b.target)
            # check that A is one of the module types we handle
            # TODO(future PR): other module types
            is_conv2d = isinstance(mod_a, nn.Conv2d) or \
                (type(mod_a), nn.Conv2d) in related_types
            if is_conv2d:
                extracted = (
                    get_conv_mod_weight(mod_a),
                    get_conv_mod_weight(mod_b),
                )

        if extracted is not None:
            weight_a, weight_b = extracted
            weights_by_match[match_name] = {
                name_a: weight_a,
                name_b: weight_b,
            }

    return weights_by_match
def test_mobilenet_v2_qat(self):
    """The QAT-prepared and converted MobileNetV2 graphs can be matched.

    Success is defined as ``get_matching_node_pairs`` raising no exception.
    """
    # verify that mobilenetv2 graph is able to be matched
    import torchvision
    m = torchvision.models.mobilenet_v2(pretrained=False).float()
    mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    # assume success if no exceptions; the result is intentionally unused
    get_matching_node_pairs(mp, mq)
def test_simple_mod(self):
    """A lone Conv2d in the prepared model pairs with nnq.Conv2d after convert."""
    model = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
    prepared = prepare_fx(model, {'': torch.quantization.default_qconfig})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    quantized = convert_fx(copy.deepcopy(prepared))
    matches = get_matching_node_pairs(prepared, quantized)
    self.assert_types_for_matched_node_pairs(
        matches,
        {'0': (nn.Conv2d, nnq.Conv2d)},
        prepared,
        quantized,
    )
def test_simple_mod_multi(self):
    """Conv2d modules at different nesting depths can be matched.

    Success is defined as ``get_matching_node_pairs`` raising no exception.
    """
    m = nn.Sequential(
        nn.Sequential(
            nn.Conv2d(1, 1, 1),
        ),
        nn.Conv2d(1, 1, 1),
    ).eval()
    mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    # assume success if no exceptions; the result is intentionally unused
    get_matching_node_pairs(mp, mq)
def prepare_model_with_stubs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> GraphModule:
    """
    Build an `a_shadows_b` model from two graph models.

    This is the `a_shadows_b` counterpart of `prepare_model_outputs`. Nodes of
    `gm_a` and `gm_b` are paired with `get_matching_node_pairs`. The pairs,
    together with both names, both models and `logger_cls`, are then passed to
    `create_a_shadows_b`, and its result is returned unchanged.

    Args:
        name_a: name identifying `gm_a`
        gm_a: first graph model
        name_b: name identifying `gm_b`
        gm_b: second graph model
        logger_cls: logger class forwarded to `create_a_shadows_b`

    Returns:
        The GraphModule built by `create_a_shadows_b`.
    """
    return create_a_shadows_b(
        name_a,
        gm_a,
        name_b,
        gm_b,
        get_matching_node_pairs(gm_a, gm_b),
        logger_cls,
    )
def test_simple_tensor_ops(self):
    """A graph holding only a tensor add can be matched after conversion.

    Success is defined as ``get_matching_node_pairs`` raising no exception.
    """
    class M(nn.Module):
        # no custom state, so the inherited nn.Module.__init__ suffices
        def forward(self, x, y):
            z = x + y
            return z

    m = M().eval()
    mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    # assume success if no exceptions; the result is intentionally unused
    get_matching_node_pairs(mp, mq)
def prepare_model_outputs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> Tuple[GraphModule, GraphModule]:
    """
    Instrument two graph models at the nodes the graph matcher pairs up.

    Nodes of `gm_a` and `gm_b` are paired with `get_matching_node_pairs`.
    Every matched A-side node and every matched B-side node is collected; no
    filtering is done yet (see the TODO below). Each model is then passed to
    `remove_observers_add_loggers` along with its node list, `logger_cls` and
    its name. NOTE(review): going by its name, that helper presumably replaces
    observers with `logger_cls` loggers at those nodes; confirm against its
    definition.

    Args:
        name_a: name identifying `gm_a`, forwarded to the instrumentation
        gm_a: first graph model
        name_b: name identifying `gm_b`, forwarded to the instrumentation
        gm_b: second graph model
        logger_cls: logger class forwarded to `remove_observers_add_loggers`

    Returns:
        `(gm_a, gm_b)` as returned by `remove_observers_add_loggers` for each
        model.
    """
    matched_node_pairs = get_matching_node_pairs(gm_a, gm_b)

    nodes_to_instrument_a = []
    nodes_to_instrument_b = []
    # only the node pairs are needed here, not the match names
    for node_a, node_b in matched_node_pairs.values():
        # TODO(future PR): do not observe pairs of nodes we do not care
        # about (both fp32, denylist, etc)
        nodes_to_instrument_a.append(node_a)
        nodes_to_instrument_b.append(node_b)

    gm_a = remove_observers_add_loggers(
        gm_a, nodes_to_instrument_a, logger_cls, name_a)
    gm_b = remove_observers_add_loggers(
        gm_b, nodes_to_instrument_b, logger_cls, name_b)
    return (gm_a, gm_b)
def test_simple_fun(self):
    """An F.linear call pairs with toq.linear after conversion."""
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            # torch.empty leaves memory uninitialized; the weight is filled by
            # kaiming_uniform_ below. The bias was previously created the same
            # way and never initialized, so it could hold arbitrary values
            # (even NaN/inf) and make the test nondeterministic; zero it.
            self.w = nn.Parameter(torch.empty(1, 4))
            self.b = nn.Parameter(torch.zeros(1))
            torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

        def forward(self, x):
            return F.linear(x, self.w, self.b)

    m = M().eval()
    mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
    # TODO(future PR): prevent the need for copying here, we can copy the
    # modules but should reuse the underlying tensors
    mp_copy = copy.deepcopy(mp)
    mq = convert_fx(mp_copy)
    results = get_matching_node_pairs(mp, mq)

    expected_types = {'linear_1': (F.linear, toq.linear)}
    self.assert_types_for_matched_node_pairs(results, expected_types, mp, mq)