Ejemplo n.º 1
0
    def test_list_sort(self):
        """Compare ``list.sort`` / ``sorted`` between eager Python and TorchScript."""
        # The same source is run eagerly and compiled as a CompilationUnit;
        # ``{list_create}`` is substituted with each list literal under test.
        template = dedent('''
        def func():
            li_1 = {list_create}
            li_2 = {list_create}
            li_3 = {list_create}
            li_1.sort()
            li_2.sort(reverse=True)
            li_4 = sorted(li_3)
            return li_1, li_2, li_3, li_4
        ''')

        list_literals = (
            "[]",
            "[1, 3, 2]",
            "[True, False, True]",
            "[1.2, .2, 3.2]",
            "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]",
            "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]",
        )
        for literal in list_literals:
            source = template.format(list_create=literal)
            eager_namespace = {}
            exec(source, globals(), eager_namespace)
            compiled = torch.jit.CompilationUnit(source)
            scripted_result = compiled.func()
            eager_result = eager_namespace['func']()
            self.assertEqual(scripted_result, eager_result)

        def test_fail(x):
            # type: (List[Tensor]) -> List[Tensor]
            x.sort()
            return x

        # Sorting tensors that hold more than one element is expected to raise
        # the ambiguous-bool error.
        multi_element_tensors = [torch.zeros([2]), torch.zeros([2])]
        self.checkScriptRaisesRegex(test_fail, (multi_element_tensors,), Exception,
                                    "bool value of Tensor with more than one value")

        @torch.jit.script
        def test_mutation():
            a = [1, 2, 3]
            a.sort()
            return a

        # Run once, then confirm the in-place sort op shows up in the graph.
        test_mutation()
        executed_graph = test_mutation.graph_for()
        FileCheck().check("aten::sort").run(executed_graph)

        def test_sorted_copy():
            a = [torch.tensor(2), torch.tensor(0), torch.tensor(1)]
            b = sorted(a)
            a[0] = torch.tensor(10)
            return a, b

        # Mutating ``a`` after ``sorted(a)`` must behave the same as in eager mode.
        self.checkScript(test_sorted_copy, ())
Ejemplo n.º 2
0
    def test_freeze_module_in_training_mode(self):
        """Freezing must fail in training mode and fold every attribute in eval mode.

        In eval mode the frozen module is expected to expose none of the
        original submodules (nor ``training``), to have no ``state_dict``, and
        to keep no ``GetAttr`` nodes in ``forward`` after a save/load round trip.
        """
        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout2d(0.25)
                self.dropout2 = nn.Dropout2d(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.functional.relu(x)
                x = self.conv2(x)
                x = nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.fc1(x)
                x = nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = nn.functional.log_softmax(x, dim=1)
                return output

        model = torch.jit.script(Net())
        model.train()

        # The call is expected to raise, so its result is intentionally not bound.
        with self.assertRaisesRegex(RuntimeError, 'Freezing module in training mode is not yet supported'):
            torch._C._freeze_module(model._c)

        model.eval()
        frozen = torch._C._freeze_module(model._c)
        folded_attrs = ('conv1', 'conv2', 'dropout1', 'training',
                        'fc1', 'dropout2', 'fc2')
        for attr_name in folded_attrs:
            self.assertFalse(frozen.hasattr(attr_name))
        # No print() around the call: it must raise, and printing would only
        # add noise to the test output if the behavior ever regressed.
        with self.assertRaisesRegex(RuntimeError, "does not have a field with name 'state_dict'"):
            frozen.state_dict()

        # Round-trip through serialization and inspect the reloaded forward graph.
        buffer = io.BytesIO()
        torch.jit.save(frozen, buffer)
        buffer.seek(0)
        reloaded = torch.jit.load(buffer)
        FileCheck().check_not('GetAttr[name=') \
                   .run(reloaded._c._get_method('forward').graph)
Ejemplo n.º 3
0
    def test_concat_invariant_cuda(self):
        """Check that a fused concat output feeds the return, not another fused node."""
        # Invariant: the output of prim::FusedConcat may
        # not be an input to any node inside the FusionGroup.
        def fn(x, y, z):
            x1 = x + y
            y1 = x - y
            w = torch.cat([x1, y1])
            return w + z

        def cuda_randn(rows):
            return torch.randn(rows, 2, dtype=torch.float, device='cuda')

        x = cuda_randn(2)
        y = cuda_randn(2)
        z = cuda_randn(4)
        args = (x, y, z)

        traced = self.checkTrace(fn, args)
        graph = traced.graph_for(*args)
        self.assertAllFused(graph, except_for={'aten::add'})

        concat_then_return = FileCheck()
        concat_then_return.check("FusedConcat")
        concat_then_return.check_next("return")
        concat_then_return.run(str(graph))
Ejemplo n.º 4
0
    def test_union_type_refinement_statically_true(self):
        """An ``isinstance`` check covering every Union member must compile branch-free."""
        @torch.jit.script
        def fn(x: Union[List[int], int]) -> Union[List[int], int]:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: Union[List[int], int] = l
                return y

        # Check that we don't have any branching statements
        branch_free = FileCheck()
        for block_header in ("block0()", "block1()"):
            branch_free.check_not(block_header)
        branch_free.run(fn.graph)
Ejemplo n.º 5
0
    def test_chunk_distributes_cuda(self):
        """Trace ``(x + y).chunk`` on CUDA and expect one ConstantChunk in the fused graph."""
        def f(x, y):
            z1, z2 = (x + y).chunk(2, dim=1)
            return z1 * z2

        x, y = (torch.randn(4, 4, dtype=torch.float, device='cuda')
                for _ in range(2))

        traced = self.checkTrace(f, (x, y))
        graph_text = str(traced.graph_for(x, y))
        # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
        # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
        #     .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
        checker = FileCheck()
        checker.check("with " + FUSION_GROUP + "_")
        checker.check_count("ConstantChunk", 1, exactly=True)
        checker.run(graph_text)
Ejemplo n.º 6
0
    def test_warn_only_once(self):
        """A warning issued repeatedly within one scripted call reaches stderr exactly once."""
        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        captured = io.StringIO()
        with redirect_stderr(captured):
            fn()

        stderr_text = captured.getvalue()
        once = FileCheck().check_count("UserWarning: I am warning you", 1, exactly=True)
        once.run(stderr_text)
Ejemplo n.º 7
0
    def test_warn_multiple_calls_multiple_warnings(self):
        """Each separate call of a warning scripted function writes its own warning to stderr."""
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        num_calls = 2
        stderr_buffer = io.StringIO()
        with redirect_stderr(stderr_buffer):
            for _ in range(num_calls):
                fn()

        expected_warning = "UserWarning: I am warning you"
        checker = FileCheck()
        checker.check_count(str=expected_warning, count=num_calls, exactly=True)
        checker.run(stderr_buffer.getvalue())
Ejemplo n.º 8
0
    def test_peephole_cuda(self):
        """Peephole must keep a cross-device ``type_as`` and drop a same-device one."""
        cpu_src = torch.tensor([0.4], device='cpu')
        cuda_src = torch.tensor([0.7], device='cuda')
        cuda_ref = torch.tensor([0.7], device='cuda')

        def f(x, y):
            return x.type_as(y)

        # CPU input converted to the CUDA reference: the pass must not alter the graph.
        cross_device = torch.jit.trace(f, (cpu_src, cuda_ref))
        graph_before = str(cross_device.graph)
        self.run_pass('peephole', cross_device.graph)
        graph_after = str(cross_device.graph)
        self.assertEqual(graph_before, graph_after)

        # Both inputs already on CUDA: after peephole + dce no type_as may remain.
        same_device = torch.jit.trace(f, (cuda_src, cuda_ref))
        for pass_name in ('peephole', 'dce'):
            self.run_pass(pass_name, same_device.graph)
        FileCheck().check_not("type_as").run(str(same_device.graph))
Ejemplo n.º 9
0
    def test_dist_backward(self):
        """Script a function calling ``dist_autograd.backward`` and run it on an RPC result."""
        # Only rank 0 drives this test; every other rank returns immediately.
        if self.rank != 0:
            return

        @torch.jit.script
        def dist_backward_script(context_id: int, loss: torch.Tensor):
            dist_autograd.backward(context_id, [loss])

        script_graph = str(dist_backward_script.graph)
        FileCheck().check("dist_backward").run(script_graph)

        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            next_rank = (self.rank + 1) % self.world_size
            remote_sum = rpc.rpc_sync(worker_name(next_rank), torch.add,
                                      args=(t1, t2))
            loss = remote_sum.sum()
            dist_backward_script(context_id, loss)
Ejemplo n.º 10
0
    def test_lstm_cuda(self):
        """Script the LSTM cell on CUDA and check fusion of the forward and backward graphs."""
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)

        fwd = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            fwd, 'prim::FusionGroup', 1, consider_subgraphs=True)
        top_level_nodes = list(fwd.nodes())
        self.assertTrue(len(top_level_nodes) == 2)
        # Everything is differentiable but TupleConstruct return
        layout = FileCheck()
        layout.check("DifferentiableGraph")
        layout.check_next("TupleConstruct")
        layout.check_next("return")
        layout.run(str(fwd))

        hy, cy = module(*inputs)
        total = (hy + cy).sum()
        total.backward()
        bwd = backward_graph(module)
        unfused_ok = ("aten::t", "aten::mm", "aten::_grad_sum_to_size")
        self.assertAllFused(bwd, except_for=unfused_ok)
Ejemplo n.º 11
0
    def test_stacktrace_interface_call(self):
        """Check the TorchScript traceback reported by the lite interpreter for an interface call.

        Module ``A`` calls ``B.forwardError`` through the ``Forward`` interface;
        that path ends in ``torch.ones(-1)``, which is expected to raise. The
        error text must contain the highlighted source of each call level.
        """
        # NOTE: the classes below are compiled from their source text and the
        # final FileCheck matches that text (including highlight widths), so
        # they must not be reformatted.
        @torch.jit.interface
        class Forward(torch.nn.Module):
            def forward(self, x) -> torch.Tensor:
                pass

            def forwardError(self, x) -> torch.Tensor:
                pass

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

            def forwardError(self, x):
                return self.call() + x

            def call(self):
                return torch.ones(-1)

        class A(torch.nn.Module):
            b : Forward

            def __init__(self):
                super().__init__()
                self.b = B()

            def forward(self):
                self.b.forward(torch.ones(1))
                self.b.forwardError(torch.ones(1))

        a = torch.jit.script(A())
        torch._C._enable_mobile_interface_call_export()
        # Save with mobile debug info and reload through the lite interpreter.
        buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
        buffer.seek(0)
        mobile_module = _load_for_lite_interpreter(buffer)
        try:
            mobile_module()
            # Reaching this line means no error was raised: fail the test.
            # (The AssertionError is not caught by the RuntimeError handler below.)
            self.assertTrue(False)
        except RuntimeError as exp:
            # Expect the error message, then one highlighted frame per call level.
            FileCheck().check("Trying to create tensor with negative dimension") \
                .check("Traceback of TorchScript") \
                .check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
                .check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
                .check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE").run(str(exp))
Ejemplo n.º 12
0
    def validate_transform_conv1d_to_conv2d(self,
                                            pattern_count_transformed_map,
                                            pattern_count_optimized_map,
                                            data_shape):
        """Validate the conv1d-to-conv2d transform and mobile optimization of this module.

        Scripts ``self``, records a reference output on random input, applies
        the conv1d-to-conv2d pass and ``optimize_for_mobile``, then for both
        resulting models: round-trips through save/load, checks graph patterns
        and compares the output against the reference.

        Args:
            pattern_count_transformed_map: maps a FileCheck pattern to the
                expectation for the transformed graph. ``0`` means the pattern
                must be present, ``-1`` means it must be absent, any other
                value is the exact expected number of occurrences.
            pattern_count_optimized_map: same encoding, applied to the graph
                of the mobile-optimized model.
            data_shape: ``size`` of the random input drawn with ``torch.normal``.
        """
        def reload_from_memory(model):
            # Serialize to an in-memory buffer and load it back.
            buffer = io.BytesIO()
            torch.jit.save(model, buffer)
            buffer.seek(0)
            return torch.jit.load(buffer)

        def check_patterns(graph, pattern_count_map):
            # Apply the 0 / -1 / exact-count expectation of every pattern.
            for pattern, count in pattern_count_map.items():
                if count == 0:
                    FileCheck().check(pattern).run(graph)
                elif count == -1:
                    FileCheck().check_not(pattern).run(graph)
                else:
                    FileCheck().check_count(pattern, count, exactly=True).run(graph)

        def validate(model, pattern_count_map):
            # Round-trip, check the graph, then compare numerics to the reference.
            reloaded = reload_from_memory(model)
            check_patterns(reloaded.graph, pattern_count_map)
            torch.testing.assert_allclose(ref_result,
                                          reloaded(input_data),
                                          rtol=1e-2,
                                          atol=1e-3)

        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        # Reference output is taken before any graph rewriting happens.
        ref_result = scripted_model(input_data)
        torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
        # Build the optimized model before the transformed one is serialized,
        # matching the original order of operations.
        optimized_scripted_model = optimize_for_mobile(scripted_model)

        validate(scripted_model, pattern_count_transformed_map)
        validate(optimized_scripted_model, pattern_count_optimized_map)
Ejemplo n.º 13
0
    def test_fake_dispatch_keys(self):
        """Check the dispatch key set printed for tensors created under FakeTensorMode."""
        with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
            fake = torch.rand([4])
            ordered_keys = FileCheck()
            for key_name in ("CPU", "ADInplaceOrView", "AutogradCPU", "AutocastCPU"):
                ordered_keys.check(key_name)
            ordered_keys.run(torch._C._dispatch_key_set(fake))

            with torch.inference_mode():
                operand = torch.rand([4])
                doubled = operand + operand
                # Under inference mode the autograd-related keys must be gone.
                FileCheck().check("CPU").check("AutocastCPU").run(
                    torch._C._dispatch_key_set(doubled))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(
                    torch._C._dispatch_key_set(doubled))
    def test_merges_down(self):
        """Slice ``fn`` into differentiable subgraphs and expect a single merged subgraph.

        Also checks that no ``aten::add`` appears in the printed graph before
        its ``return``.
        """
        # o     x --> o
        # |           ^
        #  \_________/
        def fn(v, w, x, y):
            a = v * w
            b = torch.ones(int(y))
            c = b * a
            return a, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        # add moved down
        # Only the text before the first "return" is inspected. partition()
        # keeps the whole string if the marker is missing, instead of silently
        # dropping the last character like slicing with find() == -1 would.
        graph_head = str(graph).partition("return")[0]
        FileCheck().check_not("aten::add").run(graph_head)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
Ejemplo n.º 15
0
    def test_lists_insert(self):
        """``list.insert`` calls must fold to a constant after mutation removal."""
        def successful_remove():
            a: List[int] = []
            a.insert(0, 1)
            a.insert(0, 2)
            a.insert(-10, 3)
            a.insert(-9, 4)
            a.insert(10, 5)
            return a

        scripted = torch.jit.script(successful_remove)
        graph = scripted.graph
        for graph_pass in (torch._C._jit_pass_remove_mutation,
                           torch._C._jit_pass_constant_propagation):
            graph_pass(graph)

        # The graph header must be followed directly by one Constant and the return.
        folded = FileCheck()
        folded.check("graph")
        folded.check_next("Constant")
        folded.check_next("return")
        folded.run(graph)

        eager_result = successful_remove()
        self.assertEqual(eager_result, scripted())
Ejemplo n.º 16
0
    def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False, check=True, eval_mode=True, dynamic=False):
        """Quantize ``module`` in graph mode (debug and non-debug) and check the result.

        Args:
            module: module to script/trace and quantize.
            data: when ``dynamic`` is true, passed to the quantized model as its
                single argument; otherwise ``data[0]`` is unpacked as
                ``(*inputs, target)`` and a deep copy of ``data`` is handed to
                ``quantize_jit`` together with ``test_only_eval_fn``.
            quantized_op: FileCheck pattern expected in the non-debug graph.
            tracing: forwarded to ``get_script_module``.
            debug: print the module and the intermediate graphs.
            check: compare debug/non-debug outputs and look for ``quantized_op``.
            eval_mode: call ``module.eval()`` before scripting.
            dynamic: use dynamic quantization instead of static quantization.

        Returns:
            The model quantized with ``debug=False``.
        """
        if debug:
            print('Testing:', str(module))
        static_qconfig = get_default_qconfig(torch.backends.quantized.engine)

        if eval_mode:
            module = module.eval()

        if dynamic:
            qconfig_dict = {'': default_dynamic_qconfig}
            inputs = data
        else:
            qconfig_dict = {'': static_qconfig}
            # Only the inputs of the first sample are used; its target is dropped.
            *inputs, _ = data[0]

        model = get_script_module(module, tracing, inputs).eval()
        if debug:
            print('input graph:', model.graph)

        quantized_models = {}
        results = {}
        for debug_flag in (True, False):
            # TODO: _test_only_eval_fn --> default_eval_fn
            if dynamic:
                quantized = quantize_dynamic_jit(model, qconfig_dict, debug=debug_flag)
                quantized_models[debug_flag] = quantized
                # make sure it runs
                results[debug_flag] = quantized(inputs)
            else:
                # module under test can contain in-place ops, and we depend on
                # input data staying constant for comparisons
                calibration_data = copy.deepcopy(data)
                quantized = quantize_jit(
                    model, qconfig_dict, test_only_eval_fn, [calibration_data],
                    inplace=False, debug=debug_flag)
                quantized_models[debug_flag] = quantized
                # make sure it runs
                results[debug_flag] = quantized(*inputs)

        final_model = quantized_models[False]
        if debug:
            print('debug graph:', quantized_models[True].graph)
            print('non debug graph:', final_model.graph)

        if check:
            # debug and non-debug option should have the same numerics
            self.assertEqual(results[True], results[False])

            # non debug graph should produce quantized op
            FileCheck().check(quantized_op).run(final_model.graph)

        return final_model
Ejemplo n.º 17
0
    def test_enum_value_types(self):
        """Script functions taking enums with int, float and str values; Tensor values must be rejected."""
        # Each enum is declared ``global`` before its class statement, so the
        # class name is bound at module level rather than as a local.
        global IntEnum

        class IntEnum(Enum):
            FOO = 1
            BAR = 2

        global FloatEnum

        class FloatEnum(Enum):
            FOO = 1.2
            BAR = 2.3

        global StringEnum

        class StringEnum(Enum):
            FOO = "foo as in foo bar"
            BAR = "bar as in foo bar"

        def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
            return (a.name, b.name, c.name)

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            scripted = torch.jit.script(supported_enum_types)

        # All three enum type names must appear in the graph, in this order.
        FileCheck() \
            .check("IntEnum") \
            .check("FloatEnum") \
            .check("StringEnum") \
            .run(str(scripted.graph))

        global TensorEnum

        class TensorEnum(Enum):
            FOO = torch.tensor(0)
            BAR = torch.tensor(1)

        def unsupported_enum_types(a: TensorEnum):
            return a.name

        # Scripting a function over an enum with Tensor values is expected to fail.
        with self.assertRaisesRegex(
                RuntimeError, "Cannot create Enum with value type 'Tensor'"):
            torch.jit.script(unsupported_enum_types)
    def test_differentiable_graph_ops_requires_grad(self):
        """Outputs of a scripted DifferentiableGraph must match eager values and ``requires_grad``."""
        x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
        y = torch.randn(8, 2, dtype=torch.float)

        def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
            o = x + 1.0
            o1 = torch.relu(o)
            o = y + 1.5
            o2 = torch.relu(o)
            o3 = o1 + o2

            if flag:
                o = o1 + 1.0
                oo1 = torch.relu(o)
                o = o2 + 2.5
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2
            else:
                o = o1 * 1.0
                oo1 = torch.relu(o)
                o = o2 * 2.0
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2

            return o1, o2, o3, oo1, oo2, oo3

        with enable_profiling_mode_for_profiling_tests():

            scripted = torch.jit.script(t)
            # Call the scripted function twice before inspecting its graph.
            for _ in range(2):
                scripted_outs = scripted(x, y, False)
            eager_outs = t(x, y, False)

            optimized_graph = scripted.graph_for(x, y, False)
            FileCheck().check("prim::DifferentiableGraph").run(optimized_graph)
            # validate the differentiableGraphOps are marking proper requires_grad
            for eager, jitted in zip(eager_outs, scripted_outs):
                self.assertEqual(eager.requires_grad, jitted.requires_grad)
                self.assertEqual(eager, jitted)
            # one more runs to trigger fusion
            scripted_outs = scripted(x, y, False)
            for eager, jitted in zip(eager_outs, scripted_outs):
                self.assertEqual(eager.dtype, jitted.dtype)
                self.assertEqual(eager.requires_grad, jitted.requires_grad)
                self.assertEqual(eager, jitted)
Ejemplo n.º 19
0
    def test_list_keyword(self):
        """Check that the ``list(...)`` builtin scripts the same as it runs in eager mode."""
        # list() over a list literal, a tuple, a range and a string.
        def foo():
            return list([1, 2, 3]), list(("a", "b")), list(range(5)), list("abcdefg")  # noqa: C410

        self.checkScript(foo, ())

        # list() with no arguments, annotated so the element type is known.
        def foo2():
            x: List[int] = list()
            x.append(1)
            return x,

        self.checkScript(foo2, ())

        # Nested list() calls.
        def foo3():
            return list(list("abc"))

        self.checkScript(foo3, ())
        # Both nested calls must be present as aten::list in the scripted graph.
        FileCheck().check_count("aten::list", 2, exactly=True).run(torch.jit.script(foo3).graph)
Ejemplo n.º 20
0
    def test_lstm_cuda(self):
        """Script the LSTM cell on CUDA; the graph/fusion checks below are currently skipped."""
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        # NOTE(review): this unconditional return makes everything below
        # unreachable, so only the checkScript comparison above actually runs.
        # Presumably a deliberate temporary disable -- confirm before removing
        # either the return or the dead code.
        return
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode_for_profiling_tests(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                  "aten::_grad_sum_to_size"))
Ejemplo n.º 21
0
    def test_finalize_for_linear_dynamic(self):
        """Dynamic quantization of a scripted Linear must emit ``quantized::linear_dynamic``."""
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.fc = torch.nn.Linear(5, 5).float()

            def forward(self, x):
                return self.fc(x)

        # Two (input, target) samples handed to the evaluation function.
        data = []
        for _ in range(2):
            sample_input = torch.rand((1, 5), dtype=torch.float)
            sample_target = torch.randint(0, 1, (1, ), dtype=torch.long)
            data.append((sample_input, sample_target))

        qconfig_dict = {'': default_dynamic_qconfig}
        scripted = torch.jit.script(M()).eval()
        quantized = quantize_dynamic_script(scripted, qconfig_dict,
                                            _test_only_eval_fn, [data])
        FileCheck().check("quantized::linear_dynamic").run(quantized.graph)
    def test_merge_respects_aliasing(self):
        """Differentiable-subgraph merging must not pull aliased selects across an in-place add."""
        def fn(x, k, cond):
            y = x * 1.1
            y = y * k
            y = y * 2.2
            if bool(cond):
                z1 = y[0]
                z2 = y[1]
                z1.add_(3)
                out = z2 + k + 3.3
                out = out * out
                return out

        graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
        # z2 did not get merged into the subgraph: both selects and the
        # in-place add stay adjacent inside the If, ahead of a Differentiable node.
        FileCheck().check("prim::If").check("aten::select").check_next("aten::select")\
            .check_next("aten::add_").check("Differentiable").run(graph)
        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
Ejemplo n.º 23
0
    def test_enum_name(self):
        """``.name`` on an enum argument lowers to ``prim::EnumName`` and matches Python."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        @torch.jit.script
        def enum_name(x: Color) -> str:
            return x.name

        graph_text = str(enum_name.graph)
        checker = FileCheck()
        checker.check("Color")
        checker.check_next("prim::EnumName")
        checker.check_next("return")
        checker.run(graph_text)

        for member in (Color.RED, Color.GREEN):
            self.assertEqual(enum_name(member), member.name)
Ejemplo n.º 24
0
    def test_enum_as_const(self):
        """An enum member referenced in scripted code becomes a graph constant."""
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_const(x: Color) -> bool:
            return x == Color.RED

        # Constant, comparison and return must be on consecutive graph lines.
        expected_sequence = FileCheck()
        expected_sequence.check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]")
        for following in ("aten::eq", "return"):
            expected_sequence.check_next(following)
        expected_sequence.run(str(enum_const.graph))

        cases = ((Color.RED, True), (Color.GREEN, False))
        for member, expected in cases:
            self.assertEqual(enum_const(member), expected)
Ejemplo n.º 25
0
    def test_enum_value(self):
        """``.value`` on an enum argument lowers to ``prim::EnumValue`` and matches Python."""
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def enum_value(x: Color) -> int:
            return x.value

        graph_text = str(enum_value.graph)
        FileCheck().check("Color").check_next("prim::EnumValue").check_next("return").run(graph_text)

        red_value = enum_value(Color.RED)
        self.assertEqual(red_value, Color.RED.value)
        green_value = enum_value(Color.GREEN)
        self.assertEqual(green_value, Color.GREEN.value)
Ejemplo n.º 26
0
    def test_enum_ivalue_type(self):
        """``isinstance`` against an enum type works on ``Any``-typed scripted inputs."""
        class Color(Enum):
            RED = 1
            GREEN = 2

        make_global(Color)

        @torch.jit.script
        def is_color_enum(x: Any):
            return isinstance(x, Color)

        isinstance_node = "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
        graph_check = FileCheck()
        graph_check.check(isinstance_node)
        graph_check.check_next("return")
        graph_check.run(str(is_color_enum.graph))

        # Enum members are recognized; a plain int is not.
        expectations = [(Color.RED, True), (Color.GREEN, True), (1, False)]
        for value, expected in expectations:
            self.assertEqual(is_color_enum(value), expected)
Ejemplo n.º 27
0
    def test_prepare_dynamic_lstm(self):
        """Preparing a scripted LSTM for dynamic quantization adds one observer attribute.

        The observer is looked up on the ``lstm`` submodule and must be read
        (as ``_observer_0``) before ``aten::lstm`` in that submodule's
        ``forward__0`` graph.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

            def forward(self, x):
                return self.lstm(x)

        from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
        qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                                 weight=_MinMaxTensorListObserver)
        m = torch.jit.script(M())
        m = prepare_dynamic_script(m, {'': qconfig})
        # Use a unittest assertion instead of a bare ``assert``: it is not
        # stripped under ``python -O`` and reports both values on failure.
        self.assertEqual(len(attrs_with_prefix(m.lstm, '_observer_')), 1)
        FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
                   .check("aten::lstm") \
                   .check("return") \
                   .run(str(get_module_method(m, 'lstm', 'forward__0').graph))
Ejemplo n.º 28
0
    def test_enum_comp(self):
        """Scripted equality between enum members uses ``aten::eq`` and matches Python."""
        global Color

        class Color(Enum):
            RED = 1
            GREEN = 2

        def enum_comp(x: Color, y: Color) -> bool:
            return x == y

        # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
        # is supported.
        with torch._jit_internal._disable_emit_hooks():
            compiled = torch.jit.script(enum_comp)

        graph_text = str(compiled.graph)
        FileCheck().check("aten::eq").run(graph_text)

        same_members = compiled(Color.RED, Color.RED)
        self.assertEqual(same_members, True)

        different_members = compiled(Color.RED, Color.GREEN)
        self.assertEqual(different_members, False)
Ejemplo n.º 29
0
    def test_ignore_with_types(self):
        """Script a module that calls ``@torch.jit.ignore`` callables annotated with ``Dict[str, Optional[Tensor]]``."""
        # NOTE: the final FileCheck matches identifiers taken from this source
        # ("dropout_modality", "in_batch"), so the names below must not change.
        @torch.jit.ignore
        def fn(x: Dict[str, Optional[torch.Tensor]]):
            return x + 10

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
                self.dropout_modality(in_batch)
                fn(in_batch)
                return torch.tensor(1)

            @torch.jit.ignore
            def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
                return in_batch

        # Scripting must succeed; the module is only compiled here, never run.
        sm = torch.jit.script(M())
        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
    def test_shape_analysis(self):
        """Propagate concrete and symbolic input sizes through ``x * y`` and check the output sizes."""
        @torch.jit.script
        def foo(x, y):
            return x * y

        inputs = list(foo.graph.inputs())

        def prop_shapes_on_graph(inp0, inp1):
            # Overwrite the sizes recorded on both graph inputs, then rerun propagation.
            for graph_input, sizes in zip(inputs, (inp0, inp1)):
                graph_input.setType(graph_input.type().with_sizes(sizes))
            torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

        def mul_output_sizes():
            mul_node = foo.graph.findNode("aten::mul")
            return mul_node.output().type().symbolic_sizes()

        prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
        FileCheck().check("1, 7, 6, 5").run(foo.graph)

        # None implicitly creates a new symbolic symbol
        prop_shapes_on_graph([None, None], [None, None, None])
        output_shape = mul_output_sizes()
        inp0_shape = inputs[0].type().symbolic_sizes()
        inp1_shape = inputs[1].type().symbolic_sizes()

        # output shape dim 0 should be taken from the second inp dim0
        # other two dims we cannot infer and are given a new symbolic shape
        self.assertEqual(output_shape[0], inp1_shape[0])
        for dim in (1, 2):
            self.assertFalse(output_shape[dim] in inp0_shape + inp1_shape)

        # XXX: symbolic shapes are represented with an increasing counter of unique
        # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
        # dimensions directly so there is no chance of collision between manual number
        # and current counter value.
        sym1 = torch._C._new_symbolic_shape_symbol()
        sym2 = torch._C._new_symbolic_shape_symbol()
        sym3 = torch._C._new_symbolic_shape_symbol()
        prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
        output_shape = mul_output_sizes()
        for dim, expected_symbol in enumerate((sym1, sym2, sym3)):
            self.assertEqual(output_shape[dim], expected_symbol)