    def test_prepare_dynamic(self):
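        """Prepare a scripted Linear module for dynamic quantization and
        check that activation and weight observers are inserted."""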
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})

        # one observer on the top-level module, for the input of fc (dynamic quant)
        assert len(attrs_with_prefix(m, '_observer_')) == 1
        # one observer on fc, for its weight
        assert len(attrs_with_prefix(m.fc, '_observer_')) == 1
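        # the input of fc should be observed with a DynamicQuantObserver;
        # no static Observer should be inserted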
        FileCheck().check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
                   .check('prim::GetAttr[name="fc"]') \
                   .check('prim::CallMethod') \
                   .check_not('Observer = prim::GetAttr[name="_observer_') \
                   .run(m.graph)

    def test_prepare_dynamic_child_qconfig(self):
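        """A qconfig dict entry can target a child module by its qualified
        name ('sub.fc'); only that submodule is prepared."""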
        class Sub(torch.nn.Module):
            def __init__(self):
                super(Sub, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 5, 3)
                self.sub = Sub()

            def forward(self, x):
                return self.sub(self.conv(x))

        m = torch.jit.script(M())
        # only quantize the child module (sub.fc)
        m = prepare_dynamic_script(m, {'sub.fc': default_dynamic_qconfig})

        # one observer on the top-level module, for the input of sub (dynamic quant)
        assert len(attrs_with_prefix(m, '_observer_')) == 1
        # conv has no qconfig, so it is not observed
        assert len(attrs_with_prefix(m.conv, '_observer_')) == 0
        # sub itself has no observers since we observe at the outermost call site
        assert len(attrs_with_prefix(m.sub, '_observer_')) == 0
        # one observer on sub.fc, for the weight of the linear
        assert len(attrs_with_prefix(m.sub.fc, '_observer_')) == 1
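        # the graph should contain a DynamicQuantObserver for sub's input
        # and no static Observer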
        FileCheck().check('prim::GetAttr[name="sub') \
                   .check('prim::CallMethod') \
                   .check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
                   .check('prim::CallMethod') \
                   .check_not('Observer = prim::GetAttr[name="_observer_') \
                   .run(m.graph)

    def test_prepare_dynamic_lstm(self):
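        """LSTM weights are a list of tensors, so the qconfig uses
        _MinMaxTensorListObserver as the weight observer."""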
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

            def forward(self, x):
                return self.lstm(x)

        from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
        qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                                 weight=_MinMaxTensorListObserver)
        m = torch.jit.script(M())
        m = prepare_dynamic_script(m, {'': qconfig})
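        # a single observer on the lstm module, for its weight tensor list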
        assert len(attrs_with_prefix(m.lstm, '_observer_')) == 1
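        # the tensor list observer should be inserted in the lstm submodule's
        # forward graph, before the aten::lstm call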
        FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
                   .check("aten::lstm") \
                   .check("return") \
                   .run(str(get_module_method(m, 'lstm', 'forward__0').graph))