Example 1
    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
                         check=True, eval_mode=True, dynamic=False, qconfig=None):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
        model = get_script_module(module, tracing, inputs[0]).eval()
        if debug:
            print('input graph:', model.graph)
        models = {}
        outputs = {}
        for d in [True, False]:
            if dynamic:
                models[d] = quantize_dynamic_jit(model, qconfig_dict, debug=d)
                # make sure it runs
                outputs[d] = models[d](inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                inputs_copy = copy.deepcopy(inputs)
                models[d] = quantize_jit(
                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
                    debug=d)
                # make sure it runs
                outputs[d] = models[d](*inputs[0])

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]
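
A hypothetical call site for this helper (ConvModel and the data layout are illustrative, not taken from this file): it would quantize a conv module in both debug and non-debug mode and assert that the non-debug graph contains the expected quantized op.

    # Sketch only, inside a test case: assumes a ConvModel taking one NCHW tensor.
    data = [(torch.randn(1, 3, 10, 10, dtype=torch.float),)]
    quantized = self.checkGraphModeOp(
        ConvModel(), data, "quantized::conv2d", tracing=True)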
Example 2
    def test_list_sort(self):
        template = dedent('''
        def func():
            li_1 = {list_create}
            li_2 = {list_create}
            li_3 = {list_create}
            li_1.sort()
            li_2.sort(reverse=True)
            li_4 = sorted(li_3)
            return li_1, li_2, li_3, li_4
        ''')

        lists = [
            "[]", "[1, 3, 2]", "[True, False, True]", "[1.2, .2, 3.2]",
            "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]",
            "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]"
        ]
        for li in lists:
            code = template.format(list_create=li)
            scope = {}
            exec(code, globals(), scope)
            cu = torch.jit.CompilationUnit(code)
            t1 = cu.func()
            t2 = scope['func']()
            self.assertEqual(t1, t2)

        def test_fail(x):
            # type: (List[Tensor]) -> List[Tensor]
            x.sort()
            return x

        self.checkScriptRaisesRegex(
            test_fail, ([torch.zeros([2]), torch.zeros([2])],), Exception,
            "bool value of Tensor with more than one value")

        @torch.jit.script
        def test_mutation():
            a = [1, 2, 3]
            a.sort()
            return a

        test_mutation()
        FileCheck().check("aten::sort").run(test_mutation.graph_for())
Example 3
    def test_warn_only_once_in_loop_func(self):
        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

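        # Even though w() is called ten times, the warning is emitted only
        # once, mirroring Python's default warn-once behavior.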
        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=1,
                exactly=True) \
            .run(f.getvalue())
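
FileCheck, used throughout these examples, scans an arbitrary string (a graph dump, captured stderr, an error message) for patterns in order. A minimal standalone sketch of the API as exercised above:

    from torch.testing import FileCheck

    log = "UserWarning: I am warning you"
    # check_count asserts the pattern occurs exactly `count` times
    FileCheck().check_count(str="UserWarning: I am warning you",
                            count=1, exactly=True).run(log)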
Example 4
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(forward_graph,
                                        'prim::FusionGroup',
                                        1,
                                        consider_subgraphs=True)
        self.assertTrue(len(list(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        hy, cy = module(*inputs)
        (hy + cy).sum().backward()
        backward = backward_graph(module)
        self.assertAllFused(backward,
                            except_for=("aten::t", "aten::mm",
                                        "aten::_grad_sum_to_size"))
Example 5
    def test_prepare_dynamic_lstm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

            def forward(self, x):
                return self.lstm(x)

        from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
        qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                                 weight=_MinMaxTensorListObserver)
        m = torch.jit.script(M())
        m = prepare_dynamic_script(m, {'': qconfig})
        assert len(attrs_with_prefix(m.lstm, '_observer_')) == 1
        FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
                   .check("aten::lstm") \
                   .check("return") \
                   .run(str(get_module_method(m, 'lstm', 'forward__0').graph))
Example 6
    def test_enum_as_const(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        FileCheck() \
            .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
            .check_next("aten::eq") \
            .check_next("return") \
            .run(str(enum_const.graph))

        self.assertEqual(enum_const(Color.RED), True)
        self.assertEqual(enum_const(Color.GREEN), False)
Example 7
    def test_enum_ivalue_type(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        FileCheck() \
            .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
            .check_next("return") \
            .run(str(is_color_enum.graph))

        self.assertEqual(is_color_enum(Color.RED), True)
        self.assertEqual(is_color_enum(Color.GREEN), True)
        self.assertEqual(is_color_enum(1), False)
Example 8
    def test_enum_value(self):
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(enum_value.graph))

        self.assertEqual(enum_value(Color.RED), Color.RED.value)
        self.assertEqual(enum_value(Color.GREEN), Color.GREEN.value)
Example 9
    def test_enum_name(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumName") \
            .check_next("return") \
            .run(str(enum_name.graph))

        self.assertEqual(enum_name(Color.RED), Color.RED.name)
        self.assertEqual(enum_name(Color.GREEN), Color.GREEN.name)
Example 10
    def test_enum_comp(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted_enum_comp = torch.jit.script(enum_comp)

        FileCheck().check("aten::eq").run(str(scripted_enum_comp.graph))

        self.assertEqual(scripted_enum_comp(Color.RED, Color.RED), True)

        self.assertEqual(scripted_enum_comp(Color.RED, Color.GREEN), False)
Example 11
    def test_ignore_with_types(self):
        @torch.jit.ignore
        def fn(x: Dict[str, Optional[torch.Tensor]]):
            return x + 10

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
                self.dropout_modality(in_batch)
                fn(in_batch)
                return torch.tensor(1)

            @torch.jit.ignore
            def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
                return in_batch

        sm = torch.jit.script(M())
        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
Example 12
    def test_shape_analysis(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        inputs = list(foo.graph.inputs())

        def prop_shapes_on_graph(inp0, inp1):
            inputs[0].setType(inputs[0].type().with_sizes(inp0))
            inputs[1].setType(inputs[1].type().with_sizes(inp1))
            torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

        prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
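        # Broadcasting [1, 6, 5] against [1, 7, 1, 5] yields [1, 7, 6, 5]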
        FileCheck().check("1, 7, 6, 5").run(foo.graph)

        # None implicitly creates a new symbolic symbol
        prop_shapes_on_graph([None, None], [None, None, None])
        output_shape = foo.graph.findNode(
            "aten::mul").output().type().symbolic_sizes()
        inp0_shape = inputs[0].type().symbolic_sizes()
        inp1_shape = inputs[1].type().symbolic_sizes()

        # output shape dim 0 should be taken from the second inp dim0
        # other two dims we cannot infer and are given a new symbolic shape
        self.assertEqual(output_shape[0], inp1_shape[0])
        self.assertFalse(output_shape[1] in inp0_shape + inp1_shape)
        self.assertFalse(output_shape[2] in inp0_shape + inp1_shape)

        # XXX: symbolic shapes are represented with an increasing counter of unique
        # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
        # dimensions directly so there is no chance of collision between manual number
        # and current counter value.
        sym1 = torch._C._new_symbolic_shape_symbol()
        sym2 = torch._C._new_symbolic_shape_symbol()
        sym3 = torch._C._new_symbolic_shape_symbol()
        prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
        output_shape = foo.graph.findNode(
            "aten::mul").output().type().symbolic_sizes()
        self.assertEqual(output_shape[0], sym1)
        self.assertEqual(output_shape[1], sym2)
        self.assertEqual(output_shape[2], sym3)
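
A brief illustration of the symbol API used above (a private torch._C interface, so treat this as a sketch subject to change): each call mints a fresh, unique placeholder, which is why equal dimensions must be seeded with the same symbol.

    # Fresh symbols come from an increasing counter, per the XXX note above,
    # so two separate calls never collide.
    s0 = torch._C._new_symbolic_shape_symbol()
    s1 = torch._C._new_symbolic_shape_symbol()
    assert s0 != s1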
Example 13
    def test_enum_value_types(self):
        global IntEnum

        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        global FloatEnum

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        global StringEnum

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck() \
            .check("IntEnum") \
            .check("FloatEnum") \
            .check("StringEnum") \
            .run(str(supported_enum_types.graph))

        global TensorEnum

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        with self.assertRaisesRegex(
                RuntimeError, "Cannot create Enum with value type 'Tensor'"):
            torch.jit.script(unsupported_enum_types)
Example 14
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        return  # NOTE: this early return disables the graph and fusion checks below
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(forward_graph,
                                        FUSION_GROUP,
                                        1,
                                        consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward,
                            except_for=("aten::t", "aten::mm",
                                        "aten::_grad_sum_to_size"))
Example 15
    def test_prepare_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.fc(x)

        m = torch.jit.script(M())
        m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})

        # for input of FC for dynamic quant
        assert len(attrs_with_prefix(m, '_observer_')) == 1
        # for weight
        assert len(attrs_with_prefix(m.fc, '_observer_')) == 1
        FileCheck().check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
                   .check('prim::GetAttr[name="fc"]') \
                   .check('prim::CallMethod') \
                   .check_not('Observer = prim::GetAttr[name="_observer_') \
                   .run(m.graph)
Example 16
    def test_list_type_refinement_defaults_to_Any_list_creation(self):
        def fn(x):
            tup1 = ("foo", torch.tensor(2))
            tup2 = ("bar", {"23": torch.tensor(3)})
            tup3 = ("baz", x)
            l = list((tup1, tup2))  # noqa: C410
            l.append(tup3)
            tup4 = l[0]
            if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
                t = tup4[1]
                if isinstance(t, torch.Tensor):
                    l[0] = (tup4[0], torch.add(t, t))
            return l

        self.checkScript(fn, (torch.arange(5), ))

        graph = torch.jit.script(fn).graph

        # Check that we're making a `List[Tuple[str, Any]]`
        FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
                          "[] = prim::ListConstruct()").run(graph)
Example 17
    def test_get_gradients(self):
        dst_rank = self.rank

        @torch.jit.script
        def dist_get_gradients(context_id: int) -> Dict[Tensor, Tensor]:
            return dist_autograd.get_gradients(context_id)

        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_get_gradients(context_id)

            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])
Example 18
    def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False,
                         check=True, eval_mode=True, dynamic=False):
        if debug:
            print('Testing:', str(module))
        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

        if eval_mode:
            module = module.eval()
        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig}
            inputs = data
        else:
            *inputs, target = data[0]
        model = get_script_module(module, tracing, inputs).eval()
        models = {}
        outputs = {}
        for d in [True, False]:
            # TODO: _test_only_eval_fn --> default_eval_fn
            if dynamic:
                models[d] = quantize_dynamic_script(model, qconfig_dict, debug=d)
                # make sure it runs
                outputs[d] = models[d](inputs)
            else:
                models[d] = quantize_script(
                    model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d)
                # make sure it runs
                outputs[d] = models[d](*inputs)

        if debug:
            print('debug graph:', models[True].graph)
            print('non debug graph:', models[False].graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(outputs[True], outputs[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op) \
                       .run(models[False].graph)

        return models[False]
Example 19
    def test_error_stack(self):
        def d(x: int) -> int:
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        def a(x):
            return b(x)

        try:
            scripted = torch.jit.script(a)
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("def c(x)")
            checker.check("def b(x)")
            checker.check("def a(x)")
            checker.run(str(e))
Example 20
    def test_enum_comp_diff_classes(self):
        class Foo(Enum):
            ITEM1 = 1
            ITEM2 = 2

        class Bar(Enum):
            ITEM1 = 1
            ITEM2 = 2

        make_global(Foo, Bar)

        @torch.jit.script
        def enum_comp(x: Foo) -> bool:
            return x == Bar.ITEM1

        FileCheck() \
            .check("prim::Constant") \
            .check_same("Bar.ITEM1") \
            .check("aten::eq") \
            .run(str(enum_comp.graph))

        self.assertEqual(enum_comp(Foo.ITEM1), False)
Example 21
    def test_enum_value_types(self):
        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        make_global(IntEnum, FloatEnum, StringEnum)

        @torch.jit.script
        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        FileCheck() \
            .check("IntEnum") \
            .check("FloatEnum") \
            .check("StringEnum") \
            .run(str(supported_enum_types.graph))

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        make_global(TensorEnum)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # TODO: rewrite code so that the highlight is not empty.
        with self.assertRaisesRegexWithHighlight(
                RuntimeError, "Cannot create Enum with value type 'Tensor'",
                ""):
            torch.jit.script(unsupported_enum_types)
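
For contrast, plain Python accepts tensor-valued enums without complaint; the restriction exercised above is TorchScript-specific. A quick eager-mode check (standalone sketch):

    from enum import Enum
    import torch

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    print(TensorEnum.FOO.name)  # prints 'FOO' -- fine outside torch.jit.script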
Example 22
    def test_tensor_scalar_ops_cuda(self):
        def should_fuse(x):
            z = 3.
            y = x + z
            return x * y

        # XXX: right now we only support fusing scalars if
        # they're constant (#9940)
        def should_not_fuse(x, z):
            y = x + int(z)
            return x * y

        inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
        ge = self.checkScript(should_fuse, inputs)
        graph = ge.graph_for(*inputs)
        fusion_groups = self.findFusionGroups(graph)
        self.assertEqual(len(fusion_groups), 1)
        FileCheck().check("aten::add").check("aten::mul").run(
            str(fusion_groups[0]))

        inputs = [
            torch.randn(2, 2, dtype=torch.float, device='cuda'),
            torch.tensor(3., dtype=torch.float, device='cuda'),
        ]
        ge = self.checkScript(should_not_fuse, inputs)
        # Check that the fused graph computes correct results when the scalar
        # input changes.
        inputs = [
            torch.randn(2, 2, dtype=torch.float, device='cuda'),
            torch.tensor(7., dtype=torch.float, device='cuda'),
        ]
        self.assertEqual(ge(*inputs), should_not_fuse(*inputs))
        # XXX: The TE fuser supports fusion of non-constant scalars
        # self.assertGraphContainsExactly(
        #     ge.graph_for(*inputs), FUSION_GROUP, 0, consider_subgraphs=True)
        self.assertGraphContainsExactly(ge.graph_for(*inputs),
                                        FUSION_GROUP,
                                        0,
                                        consider_subgraphs=True)
Example 23
    def test_warn_once_per_func(self):
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        f = io.StringIO()
        with redirect_stderr(f):
            fn()

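        # Two distinct call sites emit the same message text, so exactly two
        # warnings are expected: deduplication happens per warn location,
        # not per message.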
        FileCheck() \
            .check_count(
                str="UserWarning: I am warning you",
                count=2,
                exactly=True) \
            .run(f.getvalue())
Example 24
    def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
        def fn(x):
            tup1 = ("foo", torch.tensor(2))
            tup2 = ("bar", {"23": torch.tensor(3)})
            tup3 = ("baz", x)
            l_ = [tup1, tup2]
            l = [t for t in l_]  # noqa: C416
            l.append(tup3)
            tup4 = l[0]
            if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
                t = tup4[1]
                if isinstance(t, torch.Tensor):
                    l[0] = (tup4[0], torch.add(t, t))
            return l

        self.checkScript(fn, (torch.arange(5), ))

        graph = torch.jit.script(fn).graph

        print(graph)

        # Check that we're making a `List[Tuple[str, Any]]`
        FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
Example 25
    def test_enum_value(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_value(x: Color) -> int:
            return x.value

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted_enum_value = torch.jit.script(enum_value)

        FileCheck() \
            .check("Color") \
            .check_next("prim::EnumValue") \
            .check_next("return") \
            .run(str(scripted_enum_value.graph))

        self.assertEqual(scripted_enum_value(Color.RED), Color.RED.value)
        self.assertEqual(scripted_enum_value(Color.GREEN), Color.GREEN.value)
Example 26
    def test_enum_as_const(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_const(x: Color) -> bool:
            return x == Color.RED

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted = torch.jit.script(enum_const)

        FileCheck() \
            .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]") \
            .check_next("aten::eq") \
            .check_next("return") \
            .run(str(scripted.graph))

        self.assertEqual(scripted(Color.RED), True)
        self.assertEqual(scripted(Color.GREEN), False)
Example 27
    def test_enum_ivalue_type(self):
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def is_color_enum(x: Any):
            return isinstance(x, Color)

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted_is_color_enum = torch.jit.script(is_color_enum)

        FileCheck() \
            .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]") \
            .check_next("return") \
            .run(str(scripted_is_color_enum.graph))

        self.assertEqual(scripted_is_color_enum(Color.RED), True)
        self.assertEqual(scripted_is_color_enum(Color.GREEN), True)
        self.assertEqual(scripted_is_color_enum(1), False)
Example 28
    def test_class_specialization(self):
        class Foo(object):  # noqa: B903
            def __init__(self, x, y):
                self.x = x
                self.y = y

        make_global(Foo)  # see [local resolution in python]

        def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor:
            a, b = tup
            return foo.x + foo2.y + a.x + b.y

        # create from python
        x = torch.ones(2, 3)
        y = torch.zeros(2, 3)
        f = Foo(x, y)
        f2 = Foo(x * 2, y * 3)
        f3 = Foo(x * 4, y * 4)

        input = (f, f2, (f, f3))
        sfoo = self.checkScript(use_foo, input)
        graphstr = str(sfoo.graph_for(*input))
        FileCheck().check_count("prim::GetAttr", 4).run(graphstr)
Example 29
    def test_qat_and_script(self):
        class TwoLayerLinear(nn.Module):
            def __init__(self):
                super(TwoLayerLinear, self).__init__()
                self.fc1 = nn.Linear(5, 5)
                self.fc2 = nn.Linear(5, 5)

            def forward(self, x):
                x = self.fc1(x)
                return self.fc2(x)

        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.subm = TwoLayerLinear()
                self.fc = nn.Linear(5, 5)

            def forward(self, x):
                x = self.subm(x)
                x = self.fc(x)
                return x

        model = Model()
        qengine = torch.backends.quantized.engine
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)}

        # symbolically trace
        model = symbolic_trace(model)
        model = prepare_qat_fx(model, qconfig_dict)

        # ensure scripting works
        scripted = torch.jit.script(model)
        # run one round to make sure model runs
        x = torch.randn(5, 5)
        scripted(x)
        FileCheck().check_count('FakeQuantize = prim::GetAttr[name="activation_post_process', 4, exactly=True) \
                   .run(scripted.graph)
Example 30
    def test_error_stack_module(self):
        def d(x):
            # type: (int) -> int
            return x + 10

        def c(x):
            return d("hello") + d(x)

        def b(x):
            return c(x)

        class Submodule(torch.nn.Module):
            def __init__(self):
                super(Submodule, self).__init__()

            def forward(self, x):
                return b(x)

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.submodule = Submodule()

            def some_method(self, y):
                return y + self.submodule(y)

            def forward(self, x):
                return self.some_method(x)

        try:
            scripted = torch.jit.script(M())
        except RuntimeError as e:
            checker = FileCheck()
            checker.check("Expected a value of type 'int'")
            checker.check("'c' is being compiled since it was called from 'b'")
            checker.check("'b' is being compiled since it was called from")
            checker.run(str(e))