def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False, check=True, eval_mode=True, dynamic=False, qconfig=None):
    """Quantize ``module`` with the JIT graph-mode API and validate the result.

    The module is scripted (or traced when ``tracing`` is True) and then
    quantized twice: once with ``debug=True`` and once with ``debug=False``.

    Args:
        module: the eager ``nn.Module`` under test.
        inputs: list of example inputs; ``inputs[0]`` is used for scripting,
            tracing and (for static quantization) the forward call.
        quantized_op: string that must appear in the non-debug graph.
        tracing: trace instead of script when True.
        debug: print the graphs involved when True.
        check: when True, compare debug/non-debug outputs and FileCheck the
            non-debug graph for ``quantized_op``.
        eval_mode: put ``module`` into eval mode before scripting.
        dynamic: use dynamic quantization instead of static.
        qconfig: overrides ``default_dynamic_qconfig``; only consulted when
            ``dynamic`` is True.

    Returns:
        The model quantized with ``debug=False``.
    """
    if debug:
        print('Testing:', str(module))
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

    if eval_mode:
        module = module.eval()
    if dynamic:
        dynamic_qconfig = default_dynamic_qconfig if qconfig is None else qconfig
        qconfig_dict = {'': dynamic_qconfig}

    model = get_script_module(module, tracing, inputs[0]).eval()
    if debug:
        print('input graph:', model.graph)

    models = {}
    outputs = {}
    for use_debug in (True, False):
        if dynamic:
            quantized = quantize_dynamic_jit(model, qconfig_dict, debug=use_debug)
            models[use_debug] = quantized
            # make sure it runs
            outputs[use_debug] = quantized(inputs)
            continue

        # module under test can contain in-place ops, and we depend on
        # input data staying constant for comparisons
        calibration_inputs = copy.deepcopy(inputs)
        quantized = quantize_jit(
            model,
            qconfig_dict,
            test_only_eval_fn,
            [calibration_inputs],
            inplace=False,
            debug=use_debug,
        )
        models[use_debug] = quantized
        # make sure it runs
        outputs[use_debug] = quantized(*inputs[0])

    debug_model = models[True]
    release_model = models[False]
    if debug:
        print('debug graph:', debug_model.graph)
        print('non debug graph:', release_model.graph)

    if check:
        # debug and non-debug option should have the same numerics
        self.assertEqual(outputs[True], outputs[False])
        # non debug graph should produce quantized op
        FileCheck().check(quantized_op).run(release_model.graph)

    return release_model
def test_list_sort(self):
    """Check that ``list.sort``/``sorted`` agree between TorchScript and eager Python.

    Also checks that sorting a list of multi-element tensors raises, and that
    in-place ``sort`` lowers to ``aten::sort``.
    """
    template = dedent('''
    def func():
        li_1 = {list_create}
        li_2 = {list_create}
        li_3 = {list_create}
        li_1.sort()
        li_2.sort(reverse=True)
        li_4 = sorted(li_3)
        return li_1, li_2, li_3, li_4
    ''')

    list_literals = (
        "[]",
        "[1, 3, 2]",
        "[True, False, True]",
        "[1.2, .2, 3.2]",
        "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]",
        "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]",
    )
    for literal in list_literals:
        source = template.format(list_create=literal)
        eager_namespace = {}
        exec(source, globals(), eager_namespace)
        unit = torch.jit.CompilationUnit(source)
        scripted_result = unit.func()
        eager_result = eager_namespace['func']()
        self.assertEqual(scripted_result, eager_result)

    # Sorting tensors with more than one element is expected to raise with
    # the message below.
    def test_fail(x):
        # type: (List[Tensor]) -> List[Tensor]
        x.sort()
        return x

    failing_args = ([torch.zeros([2]), torch.zeros([2])],)
    self.checkScriptRaisesRegex(
        test_fail,
        failing_args,
        Exception,
        "bool value of Tensor with more than one value")

    @torch.jit.script
    def test_mutation():
        a = [1, 2, 3]
        a.sort()
        return a

    test_mutation()
    mutation_graph = test_mutation.graph_for()
    FileCheck().check("aten::sort").run(mutation_graph)
def test_warn_only_once_in_loop_func(self):
    """A warning from a helper called inside a scripted loop is printed exactly once."""
    def w():
        warnings.warn("I am warning you")

    @torch.jit.script
    def fn():
        for _ in range(10):
            w()

    captured_stderr = io.StringIO()
    with redirect_stderr(captured_stderr):
        fn()

    expect_single_warning = FileCheck().check_count(
        str="UserWarning: I am warning you", count=1, exactly=True)
    expect_single_warning.run(captured_stderr.getvalue())
def test_lstm_cuda(self):
    """Script ``LSTMCellS`` on CUDA and check forward/backward fusion.

    The forward graph must contain exactly one ``prim::FusionGroup`` (counting
    subgraphs) and consist of just two top-level nodes. The backward graph must
    be fully fused except for the listed ops.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    # assertEqual (rather than assertTrue(... == 2)) reports the actual count
    # on failure.
    self.assertEqual(len(list(forward_graph.nodes())), 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    hy, cy = module(*inputs)
    (hy + cy).sum().backward()
    backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                              "aten::_grad_sum_to_size"))
def test_prepare_dynamic_lstm(self):
    """Dynamic-quant preparation of a scripted LSTM inserts one list observer.

    The observer is a ``_MinMaxTensorListObserver`` attribute on the ``lstm``
    submodule, read via ``prim::GetAttr`` before the ``aten::lstm`` call.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

        def forward(self, x):
            return self.lstm(x)

    from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
    qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                             weight=_MinMaxTensorListObserver)
    m = torch.jit.script(M())
    m = prepare_dynamic_script(m, {'': qconfig})
    # Use a TestCase assertion instead of a bare ``assert``: it is not
    # stripped under ``python -O`` and reports the actual count on failure.
    self.assertEqual(len(attrs_with_prefix(m.lstm, '_observer_')), 1)
    FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
        .check("aten::lstm") \
        .check("return") \
        .run(str(get_module_method(m, 'lstm', 'forward__0').graph))
def test_enum_as_const(self):
    """Comparing against ``Color.RED`` embeds the enum member as a graph constant."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    @torch.jit.script
    def enum_const(x: Color) -> bool:
        return x == Color.RED

    graph_text = str(enum_const.graph)
    checker = FileCheck()
    checker = checker.check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]")
    checker = checker.check_next("aten::eq")
    checker = checker.check_next("return")
    checker.run(graph_text)

    for member, expected in ((Color.RED, True), (Color.GREEN, False)):
        self.assertEqual(enum_const(member), expected)
def test_enum_ivalue_type(self):
    """``isinstance(x, Color)`` on an ``Any`` input lowers to ``prim::isinstance``."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    @torch.jit.script
    def is_color_enum(x: Any):
        return isinstance(x, Color)

    expected_pattern = "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
    FileCheck().check(expected_pattern).check_next("return").run(str(is_color_enum.graph))

    cases = (
        (Color.RED, True),
        (Color.GREEN, True),
        (1, False),
    )
    for candidate, expected in cases:
        self.assertEqual(is_color_enum(candidate), expected)
def test_enum_value(self):
    """Reading ``.value`` of a scripted enum argument emits ``prim::EnumValue``."""
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_value(x: Color) -> int:
        return x.value

    value_checker = FileCheck().check("Color").check_next("prim::EnumValue").check_next("return")
    value_checker.run(str(enum_value.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(enum_value(member), member.value)
def test_enum_name(self):
    """Reading ``.name`` of a scripted enum argument emits ``prim::EnumName``."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    @torch.jit.script
    def enum_name(x: Color) -> str:
        return x.name

    name_checker = FileCheck().check("Color").check_next("prim::EnumName").check_next("return")
    name_checker.run(str(enum_name.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(enum_name(member), member.name)
def test_enum_comp(self):
    """Equality between two scripted enum arguments lowers to ``aten::eq``."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_comp(x: Color, y: Color) -> bool:
        return x == y

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_enum_comp = torch.jit.script(enum_comp)

    graph_text = str(scripted_enum_comp.graph)
    FileCheck().check("aten::eq").run(graph_text)

    comparisons = (
        (Color.RED, Color.RED, True),
        (Color.RED, Color.GREEN, False),
    )
    for lhs, rhs, expected in comparisons:
        self.assertEqual(scripted_enum_comp(lhs, rhs), expected)
def test_ignore_with_types(self):
    """Calls to ``@torch.jit.ignore`` functions and methods with typed args still script.

    The scripted graph must mention the ignored ``dropout_modality`` method and
    the ``in_batch`` argument passed to it.
    """
    @torch.jit.ignore
    def fn(x: Dict[str, Optional[torch.Tensor]]):
        return x + 10

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
            self.dropout_modality(in_batch)
            fn(in_batch)
            return torch.tensor(1)

        @torch.jit.ignore
        def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
            return in_batch

    scripted_module = torch.jit.script(M())
    graph_text = str(scripted_module.graph)
    FileCheck().check("dropout_modality").check("in_batch").run(graph_text)
def test_shape_analysis(self):
    """Check shape propagation for a broadcasting ``x * y`` with concrete and symbolic sizes."""
    @torch.jit.script
    def foo(x, y):
        return x * y

    inputs = list(foo.graph.inputs())

    def prop_shapes_on_graph(inp0, inp1):
        # Overwrite the input types with the given sizes and re-run shape
        # propagation on the scripted graph in place.
        inputs[0].setType(inputs[0].type().with_sizes(inp0))
        inputs[1].setType(inputs[1].type().with_sizes(inp1))
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

    prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
    FileCheck().check("1, 7, 6, 5").run(foo.graph)

    # None implicitly creates a new symbolic symbol
    prop_shapes_on_graph([None, None], [None, None, None])
    output_shape = foo.graph.findNode(
        "aten::mul").output().type().symbolic_sizes()
    inp0_shape = inputs[0].type().symbolic_sizes()
    inp1_shape = inputs[1].type().symbolic_sizes()

    # output shape dim 0 should be taken from the second inp dim0
    # other two dims we cannot infer and are given a new symbolic shape
    self.assertEqual(output_shape[0], inp1_shape[0])
    # assertNotIn reports the offending symbol and pool on failure, unlike
    # assertFalse(a in b).
    input_symbols = inp0_shape + inp1_shape
    self.assertNotIn(output_shape[1], input_symbols)
    self.assertNotIn(output_shape[2], input_symbols)

    # XXX: symbolic shapes are represented with an increasing counter of unique
    # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
    # dimensions directly so there is no chance of collision between manual number
    # and current counter value.
    sym1 = torch._C._new_symbolic_shape_symbol()
    sym2 = torch._C._new_symbolic_shape_symbol()
    sym3 = torch._C._new_symbolic_shape_symbol()
    prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
    output_shape = foo.graph.findNode(
        "aten::mul").output().type().symbolic_sizes()
    self.assertEqual(output_shape[0], sym1)
    self.assertEqual(output_shape[1], sym2)
    self.assertEqual(output_shape[2], sym3)
def test_enum_value_types(self):
    """Int/float/str-valued enums can be scripted; tensor-valued enums are rejected."""
    global IntEnum, FloatEnum, StringEnum, TensorEnum

    class IntEnum(Enum):
        FOO = 1
        BAR = 2

    class FloatEnum(Enum):
        FOO = 1.2
        BAR = 2.3

    class StringEnum(Enum):
        FOO = "foo as in foo bar"
        BAR = "bar as in foo bar"

    @torch.jit.script
    def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
        return (a.name, b.name, c.name)

    checker = FileCheck()
    for enum_name in ("IntEnum", "FloatEnum", "StringEnum"):
        checker = checker.check(enum_name)
    checker.run(str(supported_enum_types.graph))

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    def unsupported_enum_types(a: TensorEnum):
        return a.name

    with self.assertRaisesRegex(
            RuntimeError, "Cannot create Enum with value type 'Tensor'"):
        torch.jit.script(unsupported_enum_types)
def test_lstm_cuda(self):
    """Script ``LSTMCellS`` on CUDA via ``checkScript``.

    NOTE(review): the bare ``return`` right after ``checkScript`` makes every
    fusion/graph assertion below unreachable, so only ``checkScript`` runs.
    This looks like a deliberate temporary disable; consider replacing it with
    an explicit ``self.skipTest``/TODO, or removing the return once the
    assertions are expected to pass again.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    # Early exit: everything below this line is dead code (see docstring).
    return
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    # Run forward and backward with profiling enabled so a backward graph
    # is recorded for inspection.
    with enable_profiling_mode(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                              "aten::_grad_sum_to_size"))
def test_prepare_dynamic(self):
    """Dynamic-quant preparation observes the FC input on the parent and the weight on ``fc``.

    The graph must fetch one ``DynamicQuantObserver`` attribute, then the
    ``fc`` submodule, then call it, with no further observer attribute reads.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.fc = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.fc(x)

    m = torch.jit.script(M())
    m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})

    # TestCase assertions (not bare ``assert``) survive ``python -O`` and
    # report the actual counts on failure.
    # for input of FC for dynamic quant
    self.assertEqual(len(attrs_with_prefix(m, '_observer_')), 1)
    # for weight
    self.assertEqual(len(attrs_with_prefix(m.fc, '_observer_')), 1)
    FileCheck().check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
        .check('prim::GetAttr[name="fc"]') \
        .check('prim::CallMethod') \
        .check_not('Observer = prim::GetAttr[name="_observer_') \
        .run(m.graph)
def test_list_type_refinement_defaults_to_Any_list_creation(self):
    """Script a heterogeneous list built with ``list((...))`` and check its element type.

    Despite the test name, the FileCheck below expects the list to be
    constructed with element type ``(str, Union[Tensor, Dict(str, Tensor)])``,
    not ``(str, Any)``.
    """
    def fn(x):
        tup1 = ("foo", torch.tensor(2))
        tup2 = ("bar", {"23": torch.tensor(3)})
        tup3 = ("baz", x)
        l = list((tup1, tup2))  # noqa: C410
        l.append(tup3)
        tup4 = l[0]
        if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
            t = tup4[1]
            if isinstance(t, torch.Tensor):
                l[0] = (tup4[0], torch.add(t, t))
        return l

    self.checkScript(fn, (torch.arange(5), ))

    graph = torch.jit.script(fn).graph

    # Check that we're making a
    # `List[Tuple[str, Union[Tensor, Dict[str, Tensor]]]]`
    FileCheck().check("(str, Union[Tensor, Dict(str, Tensor)])"
                      "[] = prim::ListConstruct()").run(graph)
def test_get_gradients(self):
    """A scripted wrapper around ``dist_autograd.get_gradients`` returns both leaf grads.

    ``t3 = t1 + t2`` is summed and backpropagated, so each leaf's gradient is a
    3x3 tensor of ones.
    """
    # The previously assigned ``dst_rank = self.rank`` was never used.

    @torch.jit.script
    def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]):
        return dist_autograd.get_gradients(context_id)

    FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        t3 = torch.add(t1, t2)

        dist_autograd.backward(context_id, [t3.sum()])
        # Gradients are queried while the autograd context is still open.
        grads = dist_get_gradients(context_id)

        self.assertEqual(2, len(grads))
        self.assertIn(t1, grads)
        self.assertIn(t2, grads)
        self.assertEqual(torch.ones(3, 3), grads[t1])
        self.assertEqual(torch.ones(3, 3), grads[t2])
def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False, check=True, eval_mode=True, dynamic=False):
    """Quantize ``module`` in JIT graph mode (debug and non-debug) and validate it.

    Args:
        module: the eager ``nn.Module`` under test.
        data: for dynamic quantization, the inputs themselves; otherwise a
            calibration list whose first entry unpacks to ``(*inputs, target)``.
        quantized_op: string that must appear in the non-debug graph.
        tracing: trace instead of script when True.
        debug: print the graphs involved when True.
        check: when True, compare debug/non-debug outputs and FileCheck the
            non-debug graph for ``quantized_op``.
        eval_mode: put ``module`` into eval mode before scripting.
        dynamic: use dynamic quantization with ``default_dynamic_qconfig``.

    Returns:
        The model quantized with ``debug=False``.
    """
    if debug:
        print('Testing:', str(module))
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

    if eval_mode:
        module = module.eval()

    if not dynamic:
        # The trailing element of the first calibration sample is a target
        # that is not needed here.
        *inputs, _unused_target = data[0]
    else:
        qconfig_dict = {'': default_dynamic_qconfig}
        inputs = data

    model = get_script_module(module, tracing, inputs).eval()
    models = {}
    outputs = {}
    for use_debug in (True, False):
        # TODO: _test_only_eval_fn --> default_eval_fn
        if not dynamic:
            quantized = quantize_script(
                model,
                qconfig_dict,
                test_only_eval_fn,
                [data],
                inplace=False,
                debug=use_debug,
            )
            models[use_debug] = quantized
            # make sure it runs
            outputs[use_debug] = quantized(*inputs)
        else:
            quantized = quantize_dynamic_script(model, qconfig_dict, debug=use_debug)
            models[use_debug] = quantized
            # make sure it runs
            outputs[use_debug] = quantized(inputs)

    debug_model, release_model = models[True], models[False]
    if debug:
        print('debug graph:', debug_model.graph)
        print('non debug graph:', release_model.graph)

    if check:
        # debug and non-debug option should have the same numerics
        self.assertEqual(outputs[True], outputs[False])
        # non debug graph should produce quantized op
        FileCheck().check(quantized_op).run(release_model.graph)

    return release_model
def test_error_stack(self):
    """Scripting ``a`` must fail, and the error must show the a -> b -> c call stack.

    ``c`` calls ``d("hello")`` where ``d`` expects an ``int``.
    """
    def d(x: int) -> int:
        return x + 10

    def c(x):
        return d("hello") + d(x)

    def b(x):
        return c(x)

    def a(x):
        return b(x)

    # assertRaises makes the test fail if scripting unexpectedly succeeds; the
    # previous try/except silently passed in that case.
    with self.assertRaises(RuntimeError) as ctx:
        torch.jit.script(a)

    checker = FileCheck()
    checker.check("Expected a value of type 'int'")
    checker.check("def c(x)")
    checker.check("def b(x)")
    checker.check("def a(x)")
    checker.run(str(ctx.exception))
def test_enum_comp_diff_classes(self):
    """Comparing members of two distinct enum classes scripts to a constant + ``aten::eq``.

    ``Foo.ITEM1 == Bar.ITEM1`` must evaluate to False even though the values
    match.
    """
    class Foo(Enum):
        ITEM1 = 1
        ITEM2 = 2

    class Bar(Enum):
        ITEM1 = 1
        ITEM2 = 2

    make_global(Foo, Bar)

    @torch.jit.script
    def enum_comp(x: Foo) -> bool:
        return x == Bar.ITEM1

    graph_text = str(enum_comp.graph)
    checker = FileCheck().check("prim::Constant").check_same("Bar.ITEM1").check("aten::eq")
    checker.run(graph_text)

    self.assertEqual(enum_comp(Foo.ITEM1), False)
def test_enum_value_types(self):
    """Int/float/str-valued enums can be scripted; tensor-valued enums are rejected."""
    class IntEnum(Enum):
        FOO = 1
        BAR = 2

    class FloatEnum(Enum):
        FOO = 1.2
        BAR = 2.3

    class StringEnum(Enum):
        FOO = "foo as in foo bar"
        BAR = "bar as in foo bar"

    make_global(IntEnum, FloatEnum, StringEnum)

    @torch.jit.script
    def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
        return (a.name, b.name, c.name)

    type_checker = FileCheck()
    for enum_name in ("IntEnum", "FloatEnum", "StringEnum"):
        type_checker = type_checker.check(enum_name)
    type_checker.run(str(supported_enum_types.graph))

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    make_global(TensorEnum)

    def unsupported_enum_types(a: TensorEnum):
        return a.name

    # TODO: rewrite code so that the highlight is not empty.
    with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Cannot create Enum with value type 'Tensor'", ""):
        torch.jit.script(unsupported_enum_types)
def test_tensor_scalar_ops_cuda(self):
    """Check fusion of tensor ops mixed with constant vs. input scalars on CUDA.

    A constant scalar (``should_fuse``) must land in a single fusion group with
    the add and mul. A scalar supplied as a tensor input (``should_not_fuse``)
    must still give correct results when its value changes.
    """
    def should_fuse(x):
        z = 3.
        y = x + z
        return x * y

    # XXX: right now we only support fusing scalars if
    # they're constant (#9940)
    def should_not_fuse(x, z):
        y = x + int(z)
        return x * y

    def cuda_matrix():
        return torch.randn(2, 2, dtype=torch.float, device='cuda')

    def cuda_scalar(value):
        return torch.tensor(value, dtype=torch.float, device='cuda')

    fuse_args = [cuda_matrix()]
    scripted_fuse = self.checkScript(should_fuse, fuse_args)
    fused_graph = scripted_fuse.graph_for(*fuse_args)
    fusion_groups = self.findFusionGroups(fused_graph)
    self.assertEqual(len(fusion_groups), 1)
    FileCheck().check("aten::add").check("aten::mul").run(
        str(fusion_groups[0]))

    scripted_no_fuse = self.checkScript(should_not_fuse, [cuda_matrix(), cuda_scalar(3.)])

    # Check that the fused graph computes correct results when the scalar
    # input changes.
    changed_args = [cuda_matrix(), cuda_scalar(7.)]
    self.assertEqual(scripted_no_fuse(*changed_args), should_not_fuse(*changed_args))

    # XXX: The TE fuser supports fusion of non-constant scalars.
    # NOTE(review): this assertion nevertheless expects zero fusion groups;
    # confirm which fuser this test targets.
    self.assertGraphContainsExactly(
        scripted_no_fuse.graph_for(*changed_args), FUSION_GROUP, 0, consider_subgraphs=True)
def test_warn_once_per_func(self):
    """Two distinct helpers raising the same warning each emit it once (two total)."""
    def w1():
        warnings.warn("I am warning you")

    def w2():
        warnings.warn("I am warning you")

    @torch.jit.script
    def fn():
        w1()
        w2()

    captured_stderr = io.StringIO()
    with redirect_stderr(captured_stderr):
        fn()

    expect_two_warnings = FileCheck().check_count(
        str="UserWarning: I am warning you", count=2, exactly=True)
    expect_two_warnings.run(captured_stderr.getvalue())
def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
    """A heterogeneous list built by a comprehension scripts to ``List[Tuple[str, Any]]``."""
    def fn(x):
        tup1 = ("foo", torch.tensor(2))
        tup2 = ("bar", {"23": torch.tensor(3)})
        tup3 = ("baz", x)
        l_ = [tup1, tup2]
        l = [t for t in l_]  # noqa: C416
        l.append(tup3)
        tup4 = l[0]
        if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
            t = tup4[1]
            if isinstance(t, torch.Tensor):
                l[0] = (tup4[0], torch.add(t, t))
        return l

    self.checkScript(fn, (torch.arange(5), ))

    # The leftover debug ``print(graph)`` was removed; it only added noise to
    # test output.
    graph = torch.jit.script(fn).graph

    # Check that we're making a `List[Tuple[str, Any]]`
    FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
def test_enum_value(self):
    """Reading ``.value`` of a scripted enum argument emits ``prim::EnumValue``."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_value(x: Color) -> int:
        return x.value

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_enum_value = torch.jit.script(enum_value)

    value_checker = FileCheck().check("Color").check_next("prim::EnumValue").check_next("return")
    value_checker.run(str(scripted_enum_value.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(scripted_enum_value(member), member.value)
def test_enum_as_const(self):
    """Comparing against ``Color.RED`` embeds the enum member as a graph constant."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_const(x: Color) -> bool:
        return x == Color.RED

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted = torch.jit.script(enum_const)

    constant_pattern = "prim::Constant[value=__torch__.jit.test_enum.Color.RED]"
    FileCheck().check(constant_pattern).check_next("aten::eq").check_next("return") \
        .run(str(scripted.graph))

    for member, expected in ((Color.RED, True), (Color.GREEN, False)):
        self.assertEqual(scripted(member), expected)
def test_enum_ivalue_type(self):
    """``isinstance(x, Color)`` on an ``Any`` input lowers to ``prim::isinstance``."""
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def is_color_enum(x: Any):
        return isinstance(x, Color)

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_is_color_enum = torch.jit.script(is_color_enum)

    isinstance_pattern = "prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]"
    FileCheck().check(isinstance_pattern).check_next("return") \
        .run(str(scripted_is_color_enum.graph))

    cases = (
        (Color.RED, True),
        (Color.GREEN, True),
        (1, False),
    )
    for candidate, expected in cases:
        self.assertEqual(scripted_is_color_enum(candidate), expected)
def test_class_specialization(self):
    """Passing Python-created ``Foo`` objects to a scripted function yields four attribute reads."""
    class Foo(object):  # noqa: B903
        def __init__(self, x, y):
            self.x = x
            self.y = y

    make_global(Foo)  # see [local resolution in python]

    def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor:
        a, b = tup
        return foo.x + foo2.y + a.x + b.y

    # create from python
    ones = torch.ones(2, 3)
    zeros = torch.zeros(2, 3)
    first = Foo(ones, zeros)
    second = Foo(ones * 2, zeros * 3)
    third = Foo(ones * 4, zeros * 4)
    call_args = (first, second, (first, third))

    scripted = self.checkScript(use_foo, call_args)
    graph_text = str(scripted.graph_for(*call_args))
    FileCheck().check_count("prim::GetAttr", 4).run(graph_text)
def test_qat_and_script(self):
    """An FX-prepared QAT model can be scripted, run, and exposes four FakeQuantize reads."""
    class TwoLayerLinear(nn.Module):
        def __init__(self):
            super(TwoLayerLinear, self).__init__()
            self.fc1 = nn.Linear(5, 5)
            self.fc2 = nn.Linear(5, 5)

        def forward(self, x):
            x = self.fc1(x)
            return self.fc2(x)

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.subm = TwoLayerLinear()
            self.fc = nn.Linear(5, 5)

        def forward(self, x):
            x = self.subm(x)
            x = self.fc(x)
            return x

    eager_model = Model()
    engine = torch.backends.quantized.engine
    qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(engine)}

    # symbolically trace, then insert QAT fake-quant modules
    traced = symbolic_trace(eager_model)
    prepared = prepare_qat_fx(traced, qconfig_dict)

    # ensure scripting works
    scripted = torch.jit.script(prepared)

    # run one round to make sure model runs
    scripted(torch.randn(5, 5))

    fake_quant_read = 'FakeQuantize = prim::GetAttr[name="activation_post_process'
    FileCheck().check_count(fake_quant_read, 4, exactly=True).run(scripted.graph)
def test_error_stack_module(self):
    """Scripting ``M`` must fail, and the error must show why ``c`` and ``b`` were compiled.

    ``c`` calls ``d("hello")`` where ``d`` is typed ``(int) -> int``. ``c`` is
    reached through ``b``, which ``Submodule.forward`` calls.
    """
    def d(x):
        # type: (int) -> int
        return x + 10

    def c(x):
        return d("hello") + d(x)

    def b(x):
        return c(x)

    class Submodule(torch.nn.Module):
        def __init__(self):
            super(Submodule, self).__init__()

        def forward(self, x):
            return b(x)

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.submodule = Submodule()

        def some_method(self, y):
            return y + self.submodule(y)

        def forward(self, x):
            return self.some_method(x)

    # assertRaises makes the test fail if scripting unexpectedly succeeds; the
    # previous try/except silently passed in that case.
    with self.assertRaises(RuntimeError) as ctx:
        torch.jit.script(M())

    checker = FileCheck()
    checker.check("Expected a value of type 'int'")
    checker.check("'c' is being compiled since it was called from 'b'")
    checker.check("'b' is being compiled since it was called from")
    checker.run(str(ctx.exception))