def test_list_sort(self):
    """Compare list sorting between TorchScript and eager Python.

    A small ``func`` is generated per list literal, run both through
    ``torch.jit.CompilationUnit`` and plain ``exec``, and the results are
    compared. Additional cases cover sorting multi-element tensors (expected
    to raise), in-place mutation showing up as ``aten::sort`` in the graph,
    and ``sorted`` returning a copy.
    """
    template = dedent('''
    def func():
        li_1 = {list_create}
        li_2 = {list_create}
        li_3 = {list_create}
        li_1.sort()
        li_2.sort(reverse=True)
        li_4 = sorted(li_3)
        return li_1, li_2, li_3, li_4
    ''')

    list_literals = [
        "[]",
        "[1, 3, 2]",
        "[True, False, True]",
        "[1.2, .2, 3.2]",
        "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]",
        "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]",
    ]
    for list_literal in list_literals:
        code = template.format(list_create=list_literal)
        # Run the same source eagerly (via exec) and through the TorchScript compiler.
        namespace = {}
        exec(code, globals(), namespace)
        compilation_unit = torch.jit.CompilationUnit(code)
        scripted_result = compilation_unit.func()
        eager_result = namespace['func']()
        self.assertEqual(scripted_result, eager_result)

    def test_fail(x):
        # type: (List[Tensor]) -> List[Tensor]
        x.sort()
        return x

    multi_element_tensors = [torch.zeros([2]), torch.zeros([2])]
    self.checkScriptRaisesRegex(
        test_fail,
        (multi_element_tensors,),
        Exception,
        "bool value of Tensor with more than one value",
    )

    @torch.jit.script
    def test_mutation():
        a = [1, 2, 3]
        a.sort()
        return a

    test_mutation()
    sort_check = FileCheck().check("aten::sort")
    sort_check.run(test_mutation.graph_for())

    def test_sorted_copy():
        a = [torch.tensor(2), torch.tensor(0), torch.tensor(1)]
        b = sorted(a)
        a[0] = torch.tensor(10)
        return a, b

    self.checkScript(test_sorted_copy, ())
def test_freeze_module_in_training_mode(self):
    """Freezing a scripted module must fail in train mode and succeed in eval mode.

    After freezing in eval mode the test checks that the listed submodule
    and ``training`` attributes are gone, that ``state_dict`` is not
    available, and that the saved/reloaded forward graph has no ``GetAttr``.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout2d(0.25)
            self.dropout2 = nn.Dropout2d(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = nn.functional.relu(x)
            x = self.conv2(x)
            x = nn.functional.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = nn.functional.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = nn.functional.log_softmax(x, dim=1)
            return output

    model = torch.jit.script(Net())

    # Freezing while in training mode is expected to be rejected.
    model.train()
    with self.assertRaisesRegex(
            RuntimeError, 'Freezing module in training mode is not yet supported'):
        torch._C._freeze_module(model._c)

    model.eval()
    frozen_eval = torch._C._freeze_module(model._c)
    removed_attributes = (
        'conv1', 'conv2', 'dropout1', 'training', 'fc1', 'dropout2', 'fc2',
    )
    for attribute in removed_attributes:
        self.assertFalse(frozen_eval.hasattr(attribute))

    with self.assertRaisesRegex(
            RuntimeError, "does not have a field with name 'state_dict'"):
        print(frozen_eval.state_dict())

    # Round-trip through serialization and inspect the reloaded forward graph.
    buffer = io.BytesIO()
    torch.jit.save(frozen_eval, buffer)
    buffer.seek(0)
    reloaded = torch.jit.load(buffer)
    no_getattr = FileCheck().check_not('GetAttr[name=')
    no_getattr.run(reloaded._c._get_method('forward').graph)
def test_concat_invariant_cuda(self):
    """Check that a fused concat output is returned directly from the fusion.

    Invariant under test: the output of prim::FusedConcat may not be an
    input to any node inside the FusionGroup.
    """
    def fn(x, y, z):
        x1 = x + y
        y1 = x - y
        w = torch.cat([x1, y1])
        return w + z

    lhs = torch.randn(2, 2, dtype=torch.float, device='cuda')
    rhs = torch.randn(2, 2, dtype=torch.float, device='cuda')
    offset = torch.randn(4, 2, dtype=torch.float, device='cuda')
    traced = self.checkTrace(fn, (lhs, rhs, offset))
    graph = traced.graph_for(lhs, rhs, offset)

    # The trailing add is allowed to stay outside the fusion group.
    self.assertAllFused(graph, except_for={'aten::add'})
    concat_then_return = FileCheck().check("FusedConcat").check_next("return")
    concat_then_return.run(str(graph))
def test_union_type_refinement_statically_true(self): @torch.jit.script def fn(x: Union[List[int], int]) -> Union[List[int], int]: if not torch.jit.isinstance(x, (int, List[int])): return x else: l = [1, 2, 3] y: Union[List[int], int] = l return y s = fn.graph # Check that we don't have any branching statements FileCheck().check_not("block0()") \ .check_not("block1()") \ .run(s)
def test_chunk_distributes_cuda(self):
    """Trace ``(x + y).chunk(2, dim=1)`` on CUDA and inspect the fused graph.

    The graph must contain a fusion group header followed by exactly one
    ``ConstantChunk``.
    """
    def f(x, y):
        z1, z2 = (x + y).chunk(2, dim=1)
        return z1 * z2

    lhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
    rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')

    traced = self.checkTrace(f, (lhs, rhs))
    graph_text = str(traced.graph_for(lhs, rhs))
    # XXX: The old fuser does broadcast_tensors but the new fuser doesn't.
    # FileCheck().check("broadcast_tensors").check('with ' + FUSION_GROUP + '_') \
    #     .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
    fusion_header = "with " + FUSION_GROUP + "_"
    checker = FileCheck().check(fusion_header).check_count(
        "ConstantChunk", 1, exactly=True
    )
    checker.run(graph_text)
def test_warn_only_once(self):
    """A warning raised repeatedly inside one scripted call is reported once.

    The scripted function warns in a 10-iteration loop; the captured stderr
    must contain the warning text exactly one time.
    """
    @torch.jit.script
    def fn():
        for _ in range(10):
            warnings.warn("I am warning you")

    captured_stderr = io.StringIO()
    with redirect_stderr(captured_stderr):
        fn()

    single_warning = FileCheck().check_count(
        str="UserWarning: I am warning you",
        count=1,
        exactly=True,
    )
    single_warning.run(captured_stderr.getvalue())
def test_warn_multiple_calls_multiple_warnings(self):
    """Each separate call of a scripted function emits its own warning.

    The scripted function is invoked twice and the captured stderr must
    contain the warning text exactly two times.
    """
    @torch.jit.script
    def fn():
        warnings.warn("I am warning you")

    captured_stderr = io.StringIO()
    with redirect_stderr(captured_stderr):
        fn()
        fn()

    two_warnings = FileCheck().check_count(
        str="UserWarning: I am warning you",
        count=2,
        exactly=True,
    )
    two_warnings.run(captured_stderr.getvalue())
def test_peephole_cuda(self):
    """Exercise the peephole pass on ``type_as`` with CPU and CUDA inputs.

    With a CPU source and CUDA target the graph must be unchanged by the
    pass; with both tensors on CUDA, ``type_as`` must be gone after the
    peephole and dce passes.
    """
    cpu_source = torch.tensor([0.4], device='cpu')
    cuda_source = torch.tensor([0.7], device='cuda')
    cuda_target = torch.tensor([0.7], device='cuda')

    def f(x, y):
        return x.type_as(y)

    # Mixed devices: the pass must leave the graph text untouched.
    trace = torch.jit.trace(f, (cpu_source, cuda_target))
    graph_before = str(trace.graph)
    self.run_pass('peephole', trace.graph)
    self.assertEqual(graph_before, str(trace.graph))

    # Same device: the conversion should be removed.
    trace = torch.jit.trace(f, (cuda_source, cuda_target))
    for pass_name in ('peephole', 'dce'):
        self.run_pass(pass_name, trace.graph)
    no_type_as = FileCheck().check_not("type_as")
    no_type_as.run(str(trace.graph))
def test_dist_backward(self):
    """Script a call to ``dist_autograd.backward`` and run it from rank 0.

    Ranks other than 0 return immediately. The scripted graph must mention
    ``dist_backward``; the function is then invoked on the summed result of
    an ``rpc_sync`` ``torch.add`` executed on the next worker.
    """
    if self.rank != 0:
        return

    @torch.jit.script
    def dist_backward_script(context_id: int, loss: torch.Tensor):
        dist_autograd.backward(context_id, [loss])

    graph_text = str(dist_backward_script.graph)
    FileCheck().check("dist_backward").run(graph_text)

    with dist_autograd.context() as ctx_id:
        first = torch.rand(3, 3, requires_grad=True)
        second = torch.rand(3, 3, requires_grad=True)
        next_rank = (self.rank + 1) % self.world_size
        target_worker = worker_name(next_rank)
        remote_sum = rpc.rpc_sync(target_worker, torch.add, args=(first, second))
        loss = remote_sum.sum()
        dist_backward_script(ctx_id, loss)
def test_lstm_cuda(self):
    """Check fusion of the scripted LSTM cell forward and backward graphs on CUDA.

    The forward graph must contain exactly one ``prim::FusionGroup``
    (subgraphs included) and consist of two nodes: a DifferentiableGraph
    followed by a TupleConstruct. The backward graph must be fully fused
    except for the listed ops.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    # assertEqual reports the actual node count on failure, unlike
    # assertTrue(len(...) == 2).
    self.assertEqual(len(list(forward_graph.nodes())), 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    hy, cy = module(*inputs)
    (hy + cy).sum().backward()
    backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                              "aten::_grad_sum_to_size"))
def test_stacktrace_interface_call(self):
    """Check the lite-interpreter stack trace through an interface call.

    Module ``A`` holds a submodule typed as the ``Forward`` interface; its
    ``forwardError`` ends up calling ``torch.ones(-1)``. The module is
    saved for the lite interpreter with mobile debug info, reloaded, and
    executed. The resulting ``RuntimeError`` message must contain the
    TorchScript traceback with the expected ``<--- HERE`` markers.
    """
    @torch.jit.interface
    class Forward(torch.nn.Module):
        def forward(self, x) -> torch.Tensor:
            pass

        def forwardError(self, x) -> torch.Tensor:
            pass

    class B(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x

        def forwardError(self, x):
            return self.call() + x

        def call(self):
            return torch.ones(-1)

    class A(torch.nn.Module):
        b : Forward

        def __init__(self):
            super().__init__()
            self.b = B()

        def forward(self):
            self.b.forward(torch.ones(1))
            self.b.forwardError(torch.ones(1))

    a = torch.jit.script(A())
    torch._C._enable_mobile_interface_call_export()
    buffer = io.BytesIO(a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True))
    buffer.seek(0)
    mobile_module = _load_for_lite_interpreter(buffer)

    # assertRaises fails with a descriptive message if no RuntimeError is
    # raised, instead of the opaque "False is not true" from assertTrue(False).
    with self.assertRaises(RuntimeError) as raised:
        mobile_module()

    FileCheck().check("Trying to create tensor with negative dimension") \
        .check("Traceback of TorchScript") \
        .check("self.b.forwardError").check_next("~~~~~~~~~~~~~~~~~~~ <--- HERE") \
        .check("return self.call").check_next("~~~~~~~~~ <--- HERE") \
        .check("return torch.ones").check_next("~~~~~~~~~~ <--- HERE") \
        .run(str(raised.exception))
def validate_transform_conv1d_to_conv2d(self,
                                        pattern_count_transformed_map,
                                        pattern_count_optimized_map,
                                        data_shape):
    """Validate the conv1d-to-conv2d transform and the mobile-optimized model.

    ``self`` is scripted directly, so it is expected to be a scriptable
    module. A reference result is computed, the conv1d-to-conv2d JIT pass is
    applied to the scripted model, and a mobile-optimized copy is produced.
    Both models are saved, reloaded, checked against their pattern maps and
    compared numerically with the reference result.

    Args:
        pattern_count_transformed_map: maps a graph pattern to an expected
            count for the transformed model. ``0`` means the pattern must be
            present, ``-1`` means it must be absent, any other value is the
            exact number of occurrences.
        pattern_count_optimized_map: same convention, applied to the
            mobile-optimized model.
        data_shape: size passed to ``torch.normal`` to build the input.
    """
    def _save_and_reload(model):
        # Round-trip a scripted model through an in-memory buffer.
        buffer = io.BytesIO()
        torch.jit.save(model, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _check_pattern_counts(model, pattern_count_map):
        # Apply the 0 / -1 / exact-count convention to the model's graph.
        for pattern, count in pattern_count_map.items():
            if count == 0:
                FileCheck().check(pattern).run(model.graph)
            elif count == -1:
                FileCheck().check_not(pattern).run(model.graph)
            else:
                FileCheck().check_count(pattern, count, exactly=True).run(model.graph)

    scripted_model = torch.jit.script(self)
    scripted_model.eval()
    input_data = torch.normal(1, 20, size=data_shape)
    ref_result = scripted_model(input_data)
    torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c)
    optimized_scripted_model = optimize_for_mobile(scripted_model)

    deserialized_scripted_model = _save_and_reload(scripted_model)
    _check_pattern_counts(deserialized_scripted_model, pattern_count_transformed_map)
    transformed_result = deserialized_scripted_model(input_data)
    torch.testing.assert_allclose(ref_result, transformed_result, rtol=1e-2, atol=1e-3)

    deserialized_optimized_scripted_model = _save_and_reload(optimized_scripted_model)
    _check_pattern_counts(deserialized_optimized_scripted_model, pattern_count_optimized_map)
    xnnpack_result = deserialized_optimized_scripted_model(input_data)
    torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
def test_fake_dispatch_keys(self):
    """Inspect the dispatch key set of tensors created under FakeTensorMode.

    Outside inference mode the key set must list CPU, ADInplaceOrView,
    AutogradCPU and AutocastCPU in that order; under inference mode the
    result of an add must list CPU and AutocastCPU and must not mention
    ADInplaceOrView or Autograd.
    """
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.rand([4])
        full_key_check = (
            FileCheck()
            .check("CPU")
            .check("ADInplaceOrView")
            .check("AutogradCPU")
            .check("AutocastCPU")
        )
        full_key_check.run(torch._C._dispatch_key_set(x))

        with torch.inference_mode():
            x = torch.rand([4])
            y = x + x
            present = FileCheck().check("CPU").check("AutocastCPU")
            present.run(torch._C._dispatch_key_set(y))
            absent = FileCheck().check_not("ADInplaceOrView").check_not("Autograd")
            absent.run(torch._C._dispatch_key_set(y))
def test_merges_down(self):
    """Check that autodiff subgraph slicing merges nodes downward.

    After slicing, no ``aten::add`` may appear in the graph text before the
    first ``return``, and the graph must contain exactly one
    ``prim::DifferentiableGraph``.
    """
    # o  x -->    o
    # |           ^
    #  \_________/
    def fn(v, w, x, y):
        a = v * w
        b = torch.ones(int(y))
        c = b * a
        return a, c

    graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

    # add moved down
    g_str = str(graph)
    FileCheck().check_not("aten::add").run(g_str[0:g_str.find("return")])
    self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
def test_lists_insert(self):
    """List ``insert`` calls should fold to a constant after mutation removal.

    After the remove-mutation and constant-propagation passes the graph is
    expected to be just a ``Constant`` followed by ``return``, and the
    scripted function must agree with the eager one.
    """
    def successful_remove():
        a: List[int] = []
        a.insert(0, 1)
        a.insert(0, 2)
        a.insert(-10, 3)
        a.insert(-9, 4)
        a.insert(10, 5)
        return a

    scripted = torch.jit.script(successful_remove)
    graph = scripted.graph
    torch._C._jit_pass_remove_mutation(graph)
    torch._C._jit_pass_constant_propagation(graph)

    constant_only = (
        FileCheck().check("graph").check_next("Constant").check_next("return")
    )
    constant_only.run(graph)
    self.assertEqual(successful_remove(), scripted())
def checkGraphModeOp(self, module, data, quantized_op, tracing=False, debug=False,
                     check=True, eval_mode=True, dynamic=False):
    """Quantize ``module`` in graph mode with and without debug and compare.

    The module is converted to a script module, quantized twice (once with
    ``debug=True``, once with ``debug=False``) and both variants are run.
    When ``check`` is set, the two outputs must be equal and the non-debug
    graph must contain ``quantized_op``.

    Returns:
        The model quantized with ``debug=False``.
    """
    if debug:
        print('Testing:', str(module))
    qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}

    if eval_mode:
        module = module.eval()

    if not dynamic:
        # The last element of the first data entry is not used as a model input.
        *inputs, _ = data[0]
    else:
        qconfig_dict = {'': default_dynamic_qconfig}
        inputs = data

    model = get_script_module(module, tracing, inputs).eval()
    if debug:
        print('input graph:', model.graph)

    models = {}
    outputs = {}
    for debug_flag in (True, False):
        # TODO: _test_only_eval_fn --> default_eval_fn
        if dynamic:
            quantized = quantize_dynamic_jit(model, qconfig_dict, debug=debug_flag)
            models[debug_flag] = quantized
            # make sure it runs
            outputs[debug_flag] = quantized(inputs)
        else:
            # module under test can contain in-place ops, and we depend on
            # input data staying constant for comparisons
            data_copy = copy.deepcopy(data)
            quantized = quantize_jit(
                model, qconfig_dict, test_only_eval_fn, [data_copy],
                inplace=False, debug=debug_flag)
            models[debug_flag] = quantized
            # make sure it runs
            outputs[debug_flag] = quantized(*inputs)

    if debug:
        print('debug graph:', models[True].graph)
        print('non debug graph:', models[False].graph)

    if check:
        # debug and non-debug option should have the same numerics
        self.assertEqual(outputs[True], outputs[False])

        # non debug graph should produce quantized op
        op_check = FileCheck().check(quantized_op)
        op_check.run(models[False].graph)

    return models[False]
def test_enum_value_types(self):
    """Enums with int, float and str values script; Tensor-valued enums do not.

    The scripted graph must mention the three supported enum classes in
    order, and scripting a function taking a Tensor-valued enum must raise
    ``RuntimeError``.
    """
    global IntEnum

    class IntEnum(Enum):
        FOO = 1
        BAR = 2

    global FloatEnum

    class FloatEnum(Enum):
        FOO = 1.2
        BAR = 2.3

    global StringEnum

    class StringEnum(Enum):
        FOO = "foo as in foo bar"
        BAR = "bar as in foo bar"

    def supported_enum_types(a: IntEnum, b: FloatEnum, c: StringEnum):
        return (a.name, b.name, c.name)

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted = torch.jit.script(supported_enum_types)

    enum_names_in_order = (
        FileCheck().check("IntEnum").check("FloatEnum").check("StringEnum")
    )
    enum_names_in_order.run(str(scripted.graph))

    global TensorEnum

    class TensorEnum(Enum):
        FOO = torch.tensor(0)
        BAR = torch.tensor(1)

    def unsupported_enum_types(a: TensorEnum):
        return a.name

    expected_error = "Cannot create Enum with value type 'Tensor'"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        torch.jit.script(unsupported_enum_types)
def test_differentiable_graph_ops_requires_grad(self):
    """Scripted outputs must match eager outputs in value and ``requires_grad``.

    Only the first input requires grad. The scripted function is run twice
    before its graph is checked for ``prim::DifferentiableGraph``; outputs
    are compared against eager, then compared again (including dtype) after
    one more run.
    """
    x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
    y = torch.randn(8, 2, dtype=torch.float)

    def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
        o = x + 1.0
        o1 = torch.relu(o)
        o = y + 1.5
        o2 = torch.relu(o)
        o3 = o1 + o2

        if flag:
            o = o1 + 1.0
            oo1 = torch.relu(o)
            o = o2 + 2.5
            oo2 = torch.relu(o)
            oo3 = oo1 + oo2
        else:
            o = o1 * 1.0
            oo1 = torch.relu(o)
            o = o2 * 2.0
            oo2 = torch.relu(o)
            oo3 = oo1 + oo2

        return o1, o2, o3, oo1, oo2, oo3

    with enable_profiling_mode_for_profiling_tests():
        t_jit = torch.jit.script(t)
        # Two scripted runs precede the graph inspection.
        t_jit(x, y, False)
        jit_outputs = t_jit(x, y, False)
        eager_outputs = t(x, y, False)

        diff_graph_check = FileCheck().check("prim::DifferentiableGraph")
        diff_graph_check.run(t_jit.graph_for(x, y, False))

        # validate the differentiableGraphOps are marking proper requires_grad
        for eager_out, jit_out in zip(eager_outputs, jit_outputs):
            self.assertEqual(eager_out.requires_grad, jit_out.requires_grad)
            self.assertEqual(eager_out, jit_out)

        # one more runs to trigger fusion
        jit_outputs = t_jit(x, y, False)
        for eager_out, jit_out in zip(eager_outputs, jit_outputs):
            self.assertEqual(eager_out.dtype, jit_out.dtype)
            self.assertEqual(eager_out.requires_grad, jit_out.requires_grad)
            self.assertEqual(eager_out, jit_out)
def test_list_keyword(self):
    """Exercise the ``list(...)`` builtin in scripted functions.

    Covers construction from a list, a tuple, a range and a string, the
    empty ``list()`` call, and nested ``list(list(...))``, whose scripted
    graph must contain exactly two ``aten::list`` occurrences.
    """
    def foo():
        return list([1, 2, 3]), list(("a", "b")), list(range(5)), list("abcdefg")  # noqa: C410

    self.checkScript(foo, ())

    def foo2():
        x: List[int] = list()
        x.append(1)
        return x,

    self.checkScript(foo2, ())

    def foo3():
        return list(list("abc"))

    self.checkScript(foo3, ())
    nested_graph = torch.jit.script(foo3).graph
    two_list_calls = FileCheck().check_count("aten::list", 2, exactly=True)
    two_list_calls.run(nested_graph)
def test_lstm_cuda(self):
    """Script the LSTM cell on CUDA inputs via ``checkScript``.

    NOTE(review): the bare ``return`` right after ``checkScript`` makes every
    statement below it unreachable, so the fusion-group, graph-shape and
    backward checks never run. This looks like a deliberate way of disabling
    them - confirm before removing the ``return`` or the dead code.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    # Early exit: everything after this line is currently dead code.
    return
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    with enable_profiling_mode_for_profiling_tests(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                              "aten::_grad_sum_to_size"))
def test_finalize_for_linear_dynamic(self):
    """Dynamic quantization of a scripted Linear module yields a dynamic linear op.

    The quantized model's graph must contain ``quantized::linear_dynamic``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(5, 5).float()

        def forward(self, x):
            return self.fc(x)

    # Two (input, label) pairs handed to the quantization call.
    data = []
    for _ in range(2):
        sample = torch.rand((1, 5), dtype=torch.float)
        label = torch.randint(0, 1, (1, ), dtype=torch.long)
        data.append((sample, label))

    qconfig_dict = {'': default_dynamic_qconfig}
    scripted = torch.jit.script(M()).eval()
    quantized = quantize_dynamic_script(scripted, qconfig_dict, _test_only_eval_fn, [data])
    dynamic_linear = FileCheck().check("quantized::linear_dynamic")
    dynamic_linear.run(quantized.graph)
def test_merge_respects_aliasing(self):
    """Autodiff subgraph merging must not reorder across an in-place op.

    Inside the ``prim::If`` the two ``aten::select`` nodes and the
    ``aten::add_`` must stay adjacent and precede a Differentiable graph;
    the sliced graph must contain exactly two ``prim::DifferentiableGraph``.
    """
    def fn(x, k, cond):
        y = x * 1.1
        y = y * k
        y = y * 2.2
        if bool(cond):
            z1 = y[0]
            z2 = y[1]
            z1.add_(3)
            out = z2 + k + 3.3
            out = out * out
            return out

    graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)

    # z2 did not get merged into the subgraph
    ordering = (
        FileCheck()
        .check("prim::If")
        .check("aten::select")
        .check_next("aten::select")
        .check_next("aten::add_")
        .check("Differentiable")
    )
    ordering.run(graph)
    self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
def test_enum_name(self):
    """``.name`` on an enum argument lowers to ``prim::EnumName``.

    The scripted result must equal the Python ``name`` for each member.
    """
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    @torch.jit.script
    def enum_name(x: Color) -> str:
        return x.name

    graph_shape = (
        FileCheck()
        .check("Color")
        .check_next("prim::EnumName")
        .check_next("return")
    )
    graph_shape.run(str(enum_name.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(enum_name(member), member.name)
def test_enum_as_const(self):
    """An enum member referenced in script becomes a ``prim::Constant``.

    The graph must be the enum constant followed by ``aten::eq`` and
    ``return``; the scripted comparison is checked for both members.
    """
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_const(x: Color) -> bool:
        return x == Color.RED

    graph_shape = (
        FileCheck()
        .check("prim::Constant[value=__torch__.jit.test_enum.Color.RED]")
        .check_next("aten::eq")
        .check_next("return")
    )
    graph_shape.run(str(enum_const.graph))

    self.assertEqual(enum_const(Color.RED), True)
    self.assertEqual(enum_const(Color.GREEN), False)
def test_enum_value(self):
    """``.value`` on an enum argument lowers to ``prim::EnumValue``.

    The scripted result must equal the Python ``value`` for each member.
    """
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def enum_value(x: Color) -> int:
        return x.value

    graph_shape = (
        FileCheck()
        .check("Color")
        .check_next("prim::EnumValue")
        .check_next("return")
    )
    graph_shape.run(str(enum_value.graph))

    for member in (Color.RED, Color.GREEN):
        self.assertEqual(enum_value(member), member.value)
def test_enum_ivalue_type(self):
    """``isinstance(x, Color)`` on an ``Any`` argument works in script.

    The graph must be a ``prim::isinstance`` on the enum type followed by
    ``return``; enum members report True and the int ``1`` reports False.
    """
    class Color(Enum):
        RED = 1
        GREEN = 2

    make_global(Color)

    @torch.jit.script
    def is_color_enum(x: Any):
        return isinstance(x, Color)

    graph_shape = (
        FileCheck()
        .check("prim::isinstance[types=[Enum<__torch__.jit.test_enum.Color>]]")
        .check_next("return")
    )
    graph_shape.run(str(is_color_enum.graph))

    self.assertEqual(is_color_enum(Color.RED), True)
    self.assertEqual(is_color_enum(Color.GREEN), True)
    self.assertEqual(is_color_enum(1), False)
def test_prepare_dynamic_lstm(self):
    """Preparing a scripted LSTM for dynamic quantization inserts one observer.

    After ``prepare_dynamic_script`` the ``lstm`` submodule must have exactly
    one attribute prefixed ``_observer_``, and its ``forward__0`` graph must
    fetch that observer before ``aten::lstm`` and ``return``.
    """
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)

        def forward(self, x):
            return self.lstm(x)

    from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
    qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
                             weight=_MinMaxTensorListObserver)
    m = torch.jit.script(M())
    m = prepare_dynamic_script(m, {'': qconfig})
    # Use a unittest assertion rather than a bare assert: it is not stripped
    # under -O and reports the actual observer count on failure.
    self.assertEqual(len(attrs_with_prefix(m.lstm, '_observer_')), 1)
    FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
        .check("aten::lstm") \
        .check("return") \
        .run(str(get_module_method(m, 'lstm', 'forward__0').graph))
def test_enum_comp(self):
    """Equality between two enum arguments lowers to ``aten::eq``.

    The scripted comparison is checked for equal and unequal members.
    """
    global Color

    class Color(Enum):
        RED = 1
        GREEN = 2

    def enum_comp(x: Color, y: Color) -> bool:
        return x == y

    # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization
    # is supported.
    with torch._jit_internal._disable_emit_hooks():
        scripted_enum_comp = torch.jit.script(enum_comp)

    graph_text = str(scripted_enum_comp.graph)
    FileCheck().check("aten::eq").run(graph_text)

    self.assertEqual(scripted_enum_comp(Color.RED, Color.RED), True)
    self.assertEqual(scripted_enum_comp(Color.RED, Color.GREEN), False)
def test_ignore_with_types(self):
    """Ignored functions with Dict/Optional annotations can be called from script.

    The scripted module's graph must mention ``dropout_modality`` and then
    ``in_batch``.
    """
    @torch.jit.ignore
    def fn(x: Dict[str, Optional[torch.Tensor]]):
        return x + 10

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
            self.dropout_modality(in_batch)
            fn(in_batch)
            return torch.tensor(1)

        @torch.jit.ignore
        def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
            return in_batch

    scripted_module = torch.jit.script(M())
    graph_text = str(scripted_module.graph)
    ignored_call = FileCheck().check("dropout_modality").check("in_batch")
    ignored_call.run(graph_text)
def test_shape_analysis(self):
    """Check shape propagation on the graph of a scripted ``x * y``.

    Three cases are covered: fully concrete input sizes, fully unknown
    sizes (``None``), and explicitly created symbolic dimensions.
    """
    @torch.jit.script
    def foo(x, y):
        return x * y

    inputs = list(foo.graph.inputs())

    def prop_shapes_on_graph(inp0, inp1):
        # Overwrite the input types with the given sizes, then re-run propagation.
        inputs[0].setType(inputs[0].type().with_sizes(inp0))
        inputs[1].setType(inputs[1].type().with_sizes(inp1))
        torch._C._jit_pass_propagate_shapes_on_graph(foo.graph)

    prop_shapes_on_graph([1, 6, 5], [1, 7, 1, 5])
    FileCheck().check("1, 7, 6, 5").run(foo.graph)

    # None implicitly creates a new symbolic symbol
    prop_shapes_on_graph([None, None], [None, None, None])
    output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
    inp0_shape = inputs[0].type().symbolic_sizes()
    inp1_shape = inputs[1].type().symbolic_sizes()

    # output shape dim 0 should be taken from the second inp dim0
    # other two dims we cannot infer and are given a new symbolic shape
    self.assertEqual(output_shape[0], inp1_shape[0])
    # assertNotIn reports the offending symbol and the searched list on
    # failure, unlike assertFalse(x in y).
    input_symbols = inp0_shape + inp1_shape
    self.assertNotIn(output_shape[1], input_symbols)
    self.assertNotIn(output_shape[2], input_symbols)

    # XXX: symbolic shapes are represented with an increasing counter of unique
    # values, use `_new_symbolic_shape_symbol` api instead of specifying negative
    # dimensions directly so there is no chance of collision between manual number
    # and current counter value.
    sym1 = torch._C._new_symbolic_shape_symbol()
    sym2 = torch._C._new_symbolic_shape_symbol()
    sym3 = torch._C._new_symbolic_shape_symbol()
    prop_shapes_on_graph([sym1, 1, sym3], [1, sym2, sym3])
    output_shape = foo.graph.findNode("aten::mul").output().type().symbolic_sizes()
    self.assertEqual(output_shape[0], sym1)
    self.assertEqual(output_shape[1], sym2)
    self.assertEqual(output_shape[2], sym3)