コード例 #1
0
 def test_matching_failure_node_type(self):
     """Matching graphs whose node types differ must fail.

     One model contains only an ``nn.Conv2d`` and the other only an
     ``nn.Linear``; ``get_matching_node_pairs`` is expected to raise
     ``GraphMatchingException`` for this pair of prepared models.
     """
     m1 = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
     m2 = nn.Sequential(nn.Linear(1, 1)).eval()
     mp1 = prepare_fx(m1, {'': torch.quantization.default_qconfig})
     mp2 = prepare_fx(m2, {'': torch.quantization.default_qconfig})
     # Only the fact that the call raises matters; neither the exception
     # object nor a return value is inspected, so nothing is bound.
     with self.assertRaises(GraphMatchingException):
         get_matching_node_pairs(mp1, mp2)
コード例 #2
0
def compare_weights(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Collect the weights of matched node pairs from two graph modules.

    ``gm_a`` and ``gm_b`` are paired up with ``get_matching_node_pairs``.
    A pair contributes an entry only when node A is

    * a ``call_function`` node targeting ``F.linear``, or a function listed
      together with ``F.linear`` in ``get_type_a_related_to_b()``; or
    * a ``call_module`` node whose module is an ``nn.Conv2d``, or whose
      module type is listed together with ``nn.Conv2d`` in
      ``get_type_a_related_to_b()``.

    Only node A's type is checked; node B is assumed to be related because
    the graph matcher paired them. All other pairs are skipped.

    Args:
        name_a: key under which model A's weight is stored for each match.
        gm_a: the first model.
        name_b: key under which model B's weight is stored for each match.
        gm_b: the second model.

    Returns:
        A dict mapping each match name to ``{name_a: weight_a, name_b: weight_b}``.
    """
    related_pairs = get_type_a_related_to_b()
    node_pairs = get_matching_node_pairs(gm_a, gm_b)

    weights_by_match: Dict[str, Dict[str, torch.Tensor]] = {}

    for match_name, (node_a, node_b) in node_pairs.items():
        assert node_a.op == node_b.op and \
            node_a.op in ('call_function', 'call_module')

        if node_a.op == 'call_function':
            # Only F.linear (and functions registered as related to it) is
            # handled for now.
            # TODO(future PR): other function types
            is_linear_like = (
                node_a.target in (F.linear,)
                or (node_a.target, F.linear) in related_pairs
            )
            if not is_linear_like:
                continue
            weights_by_match[match_name] = {
                name_a: get_linear_fun_weight(node_a, gm_a),
                name_b: get_linear_fun_weight(node_b, gm_b),
            }
            continue

        # call_module: the submodules must be fetched to inspect their types.
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        # Only A is type-checked; the graph matcher already paired B with A.
        # TODO(future PR): other module types
        is_conv2d_like = (
            isinstance(mod_a, nn.Conv2d)
            or (type(mod_a), nn.Conv2d) in related_pairs
        )
        if is_conv2d_like:
            weights_by_match[match_name] = {
                name_a: get_conv_mod_weight(mod_a),
                name_b: get_conv_mod_weight(mod_b),
            }

    return weights_by_match
コード例 #3
0
 def test_mobilenet_v2_qat(self):
     """The QAT-prepared and converted MobileNetV2 graphs can be matched.

     The test passes as long as ``get_matching_node_pairs`` raises no
     exception; the matches themselves are not inspected.
     """
     import torchvision
     m = torchvision.models.mobilenet_v2(pretrained=False).float()
     mp = prepare_qat_fx(m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')})
     # TODO(future PR): prevent the need for copying here, we can copy the
     # modules but should reuse the underlying tensors
     mp_copy = copy.deepcopy(mp)
     mq = convert_fx(mp_copy)
     # assume success if no exceptions
     get_matching_node_pairs(mp, mq)
コード例 #4
0
    def test_simple_mod(self):
        """A single Conv2d pairs with its quantized counterpart.

        The prepared model's node ``'0'`` is expected to be an ``nn.Conv2d``
        and the converted model's matching node an ``nnq.Conv2d``.
        """
        model = nn.Sequential(nn.Conv2d(1, 1, 1)).eval()
        prepared = prepare_fx(model, {'': torch.quantization.default_qconfig})
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        quantized = convert_fx(copy.deepcopy(prepared))
        node_pairs = get_matching_node_pairs(prepared, quantized)

        self.assert_types_for_matched_node_pairs(
            node_pairs,
            {'0': (nn.Conv2d, nnq.Conv2d)},
            prepared,
            quantized,
        )
コード例 #5
0
 def test_simple_mod_multi(self):
     """Conv2d modules at different nesting depths can be matched.

     One Conv2d sits inside a nested ``nn.Sequential`` and one at the top
     level. The test passes as long as ``get_matching_node_pairs`` raises no
     exception.
     """
     m = nn.Sequential(
         nn.Sequential(nn.Conv2d(1, 1, 1)),
         nn.Conv2d(1, 1, 1),
     ).eval()
     mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
     # TODO(future PR): prevent the need for copying here, we can copy the
     # modules but should reuse the underlying tensors
     mp_copy = copy.deepcopy(mp)
     mq = convert_fx(mp_copy)
     # assume success if no exceptions
     get_matching_node_pairs(mp, mq)
コード例 #6
0
def prepare_model_with_stubs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> GraphModule:
    """
    Build an ``a_shadows_b`` model from ``gm_a`` and ``gm_b``.

    This is the shadow-model counterpart of ``prepare_model_outputs``. The two
    graphs are first paired with ``get_matching_node_pairs``. Both models,
    their names, the resulting matches and ``logger_cls`` are then passed to
    ``create_a_shadows_b``, and its result is returned unchanged.

    TODO(future PR): document the structure of the returned module.
    """
    return create_a_shadows_b(
        name_a,
        gm_a,
        name_b,
        gm_b,
        get_matching_node_pairs(gm_a, gm_b),
        logger_cls,
    )
コード例 #7
0
    def test_simple_tensor_ops(self):
        """A model consisting only of a tensor ``+`` can be matched.

        The test passes as long as ``get_matching_node_pairs`` raises no
        exception; the matches themselves are not inspected.
        """
        class M(nn.Module):
            # No custom __init__ needed: nn.Module's default is sufficient.
            def forward(self, x, y):
                return x + y

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        # assume success if no exceptions
        get_matching_node_pairs(mp, mq)
コード例 #8
0
def prepare_model_outputs(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    logger_cls: Callable,
) -> Tuple[GraphModule, GraphModule]:
    """
    Instrument the matched nodes of two models with loggers.

    ``gm_a`` and ``gm_b`` are paired with ``get_matching_node_pairs``. Every
    matched node of each model is then passed to
    ``remove_observers_add_loggers``, together with ``logger_cls`` and that
    model's name.

    Args:
        name_a: name of model A, passed to ``remove_observers_add_loggers``
            when instrumenting ``gm_a``.
        gm_a: the first model.
        name_b: name of model B, passed when instrumenting ``gm_b``.
        gm_b: the second model.
        logger_cls: logger class handed to ``remove_observers_add_loggers``.

    Returns:
        ``(gm_a, gm_b)`` as returned by ``remove_observers_add_loggers``.
    """
    matched_node_pairs = get_matching_node_pairs(gm_a, gm_b)

    nodes_to_instrument_a = []
    nodes_to_instrument_b = []
    # Only the node pairs are needed here; the match names are not used.
    for node_a, node_b in matched_node_pairs.values():
        # TODO(future PR): do not observe pairs of nodes we do not care
        #   about (both fp32, denylist, etc)
        nodes_to_instrument_a.append(node_a)
        nodes_to_instrument_b.append(node_b)

    logged_a = remove_observers_add_loggers(
        gm_a, nodes_to_instrument_a, logger_cls, name_a)
    logged_b = remove_observers_add_loggers(
        gm_b, nodes_to_instrument_b, logger_cls, name_b)
    return (logged_a, logged_b)
コード例 #9
0
    def test_simple_fun(self):
        """A functional ``F.linear`` pairs with its quantized counterpart.

        The prepared model's ``linear_1`` node is expected to be an
        ``F.linear`` call, and the converted model's matching node a
        ``toq.linear`` call.
        """
        class M(nn.Module):
            def __init__(self):
                super().__init__()
                # torch.Tensor(...) / torch.empty(...) leave storage
                # uninitialized. The weight is filled by kaiming_uniform_
                # below. The bias is zeroed so the model never holds garbage
                # values.
                self.w = nn.Parameter(torch.empty(1, 4))
                self.b = nn.Parameter(torch.zeros(1))
                torch.nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

            def forward(self, x):
                return F.linear(x, self.w, self.b)

        m = M().eval()
        mp = prepare_fx(m, {'': torch.quantization.default_qconfig})
        # TODO(future PR): prevent the need for copying here, we can copy the
        # modules but should reuse the underlying tensors
        mp_copy = copy.deepcopy(mp)
        mq = convert_fx(mp_copy)
        results = get_matching_node_pairs(mp, mq)

        expected_types = {'linear_1': (F.linear, toq.linear)}
        self.assert_types_for_matched_node_pairs(results, expected_types, mp, mq)