def test_fetch(self):
    attrs_for_lowering: Dict[str, List[str]] = {
        "torch.nn.modules.conv.Conv2d": [
            "weight", "bias", "kernel_size", "stride",
            "padding", "dilation", "groups", "padding_mode"
        ],
        "torch.nn.modules.batchnorm.BatchNorm2d": [
            "weight", "bias", "running_mean", "running_var", "eps"
        ],
    }

    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 2)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, a):
            a = self.conv(a)
            a += a
            return self.bn(a)

    mod = TestModule()
    traced = symbolic_trace(mod)
    lift_lowering_attrs_to_nodes(traced)

    for node in traced.graph.nodes:
        if node.op == "call_module":
            assert hasattr(node, "attrs_for_lowering")
            para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

            # node.attrs_for_lowering has one additional field: the class name
            assert len(para_list) + 1 == len(node.attrs_for_lowering)
            for p_name in para_list:
                assert p_name in node.attrs_for_lowering
def test_subgraph_uniquename(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a, b, c, d):
            add_1 = a + b
            add_2 = add_1 + c
            linear_1 = self.linear(add_1)
            add_3 = add_2 + d
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
    mm = MyModule()
    traced = symbolic_trace(mm)

    def split_cb(node: torch.fx.Node):
        if node.name == 'a' or node.name == 'b' or node.name == 'add':
            return 0
        else:
            return 1

    module_with_submodule = split_module(traced, mm, split_cb)
    self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))
def test_large_node_error(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            linear = self.linear(a)
            add = linear + a
            return add

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    partitioner = Partitioner()
    devices = [
        Device("dev_0", 40, 0),
        Device("dev_1", 40, 0),
        Device("dev_2", 40, 0),
        Device("dev_3", 40, 0),
        Device("dev_4", 40, 0)
    ]
    partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
    catch_runtime_error = False
    try:
        ret = partitioner.partition_graph(traced, m, partitioner_config)
    except RuntimeError:
        catch_runtime_error = True
    assert catch_runtime_error
def test_size_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.c = torch.rand(4)

        def forward(self, a, b):
            add_1 = a + b
            linear = self.linear(add_1)
            add_2 = linear + self.c
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device("dev_0", 125, 0),
        Device("dev_1", 125, 1),
        Device("dev_2", 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    for i, node in enumerate(dag.nodes):
        assert node.logical_device_ids == [i]
def test_partition_combining(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_0 = torch.nn.Linear(4, 4)

        def forward(self, a, b):
            add_1 = a + b
            c = self.linear_0(a)
            add_2 = c + add_1
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert len(module_with_submodules.graph.nodes) == 5
def test_partition_device_mapping(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            b = torch.rand(4)
            add_1 = a + b
            linear_1 = self.linear(add_1)
            add_2 = torch.rand(4) + a
            add_3 = add_2 + linear_1
            return add_3

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    partitioner = Partitioner()
    devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    for i, node in enumerate(dag.nodes):
        if i == 1:
            assert node.logical_device_ids == [1]
        else:
            assert node.logical_device_ids == [0]
def test_to_folder(self):
    class Test(torch.nn.Module):
        def __init__(self):
            super(Test, self).__init__()
            self.W = torch.nn.Parameter(torch.randn(2))
            self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
            self.linear = torch.nn.Linear(2, 2)
            self.attr = torch.randn(2)
            self.register_buffer('attr2', torch.randn(2))

        def forward(self, x):
            return self.linear(self.seq(self.W + self.attr + self.attr2 + x))

    mod = symbolic_trace(Test())
    module_name = 'Foo'
    import tempfile
    from pathlib import Path
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        mod.to_folder(tmp_dir, module_name)
        # Recipe taken from here:
        # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
        import importlib.util
        spec = importlib.util.spec_from_file_location(module_name, tmp_dir / '__init__.py')
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        t = torch.randn(2, 2)
        self.assertEqual(module.Foo()(t), mod(t))
def test_partition_node_manipulation(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            add_1 = a + b
            add_2 = add_1 + torch.rand(4)
            add_3 = add_2 + torch.rand(4)
            return add_3

    m = TestModule()
    traced = symbolic_trace(m)
    a, b = torch.rand(4), torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [Device('dev_0', 1000, 0)]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    partition = partitioner.partitions[0]
    assert partition.used_mem_bytes == 112
    # Select add_2 node to remove
    selected_node = None
    for node in partition.nodes:
        if node.name == 'add_2':
            selected_node = node
    partition.remove_node(selected_node)
    assert partition.used_mem_bytes == 80
def test_partition_combining(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            b = torch.rand(4)
            add_1 = a + b
            linear_1 = self.linear(add_1)
            add_2 = torch.rand(4) + a
            add_3 = add_2 + linear_1
            return add_3

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    partitioner = Partitioner()
    devices = [Device('dev_0', 120, 0), Device('dev_1', 144, 1)]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=False)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    assert dag.nodes[0].logical_device_ids == [0]
    assert dag.nodes[0].size_bytes == 80
    assert dag.nodes[1].logical_device_ids == [1]
    assert dag.nodes[1].size_bytes == 144
def test_size_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a, b):
            add_1 = a + b
            linear = self.linear(add_1)
            e = torch.rand(4)
            add_2 = linear + e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125),
        Device('dev_1', 125),
        Device('dev_2', 125)
    ]
    ret = partitioner.partition_graph(traced, m, devices)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert len(module_with_submodules.graph.nodes) == 7
def test_sparse_nn_partition(self):
    class MyRecommendationModule(torch.nn.Module):
        def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
            layers = torch.nn.ModuleList()
            for _ in range(num_of_layers):
                ll = torch.nn.Linear(input_size, output_size)
                layers.append(ll)
                layers.append(torch.nn.ReLU())
            return layers

        def __init__(self):
            super(MyRecommendationModule, self).__init__()
            layers = self.create_mlp(4, 4, 4)
            self.bottom_layers = torch.nn.Sequential(*layers)
            layers = self.create_mlp(3, 24, 24)
            self.top_layers = torch.nn.Sequential(*layers)
            self.embedding_layers = torch.nn.ModuleList()
            el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
            self.embedding_layers.append(el)
            for i in range(3):
                el = torch.nn.EmbeddingBag(1000000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)
            el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
            self.embedding_layers.append(el)

        def forward(self, a, b, offset):
            x = self.bottom_layers(a)
            y = []
            c = []
            for i in range(len(self.embedding_layers)):
                temp = torch.randint(10, (8, ))
                c.append(temp + b)
            for i in range(len(self.embedding_layers)):
                if i % 2 == 0:
                    y.append(self.embedding_layers[i](c[i], offset))
                else:
                    y.append(self.embedding_layers[i](torch.randint(10, (8, )), offset))
            z = torch.cat([x] + y, dim=1)
            p = self.top_layers(z)
            return p

    m = MyRecommendationModule()
    a = torch.rand(2, 4)
    b = torch.randint(10, (8, ))
    offset = torch.randint(1, (2, ))
    traced = symbolic_trace(m)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b, offset])
    devices = [
        Device('dev_0', 33000000, 0),
        Device('dev_1', 33000000, 1),
        Device('dev_2', 33000000, 2)
    ]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
    assert len(module_with_submodules.graph.nodes) == 24
def test_subgraph_trivial_resnet(self):
    # Smoke test that trivially splitting resnet into 1 partition works.
    # There was an issue before that caused submodule names to be aliased.
    m = resnet18()
    traced = symbolic_trace(m)
    a = torch.rand(64, 3, 7, 7)
    module_with_submodules = split_module(traced, m, lambda node: 0)
    module_with_submodules(a)
def test_cost_aware_partition(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 1)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3)
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    partitioner_config = PartitionerConfig(
        devices,
        is_sparse_nn=False,
        is_cost_aware=True,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec)
    assert critical_path_latency_sec == 160.
def test_partition_latency(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + linear_1
            add_4 = add_2 + add_3
            return add_4

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Given an fx module, generate node latency for each node
        based on the size of each node.
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {"output", "placeholder", "get_attr"}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                    )
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size
                    )
        return node_to_latency_mapping

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping
    )
    for p in partition_to_latency_mapping:
        if p.partition_id == 0:
            assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
        else:
            assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
    transfer_rate_bytes_per_sec = 2
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
    )
    assert critical_path_latency_sec == 208.0
def test_normalize_binary_operators(self):
    ops_to_test = {
        torch.add,
        torch.mul,
        torch.sub,
        torch.div,
        torch.floor_divide,
        torch.remainder,
        torch.eq,
        torch.ne,
        torch.lt,
        torch.le,
        torch.gt,
        torch.ge,
    }

    # Test Tensor/Tensor callsite
    for op in ops_to_test:
        class WrapperMod(torch.nn.Module):
            def forward(self, x, y):
                return op(x, y)

        traced = symbolic_trace(WrapperMod())
        normalized = NormalizeOperators(traced).transform()
        x, y = torch.randn(3, 4), torch.randn(3, 4)
        torch.testing.assert_allclose(traced(x, y), normalized(x, y))
        self.assertFalse(any(n.target in ops_to_test for n in normalized.graph.nodes))

    # Test Tensor/scalar callsite
    for op in ops_to_test:
        class WrapperMod(torch.nn.Module):
            def forward(self, x):
                return op(x, 42)

        traced = symbolic_trace(WrapperMod())
        normalized = NormalizeOperators(traced).transform()
        x = torch.randn(3, 4)
        torch.testing.assert_allclose(traced(x), normalized(x))
        self.assertFalse(any(n.target in ops_to_test for n in normalized.graph.nodes))
def test_conv_bn_fusion(self):
    rn18 = resnet18().eval()
    traced = symbolic_trace(rn18)
    fused = fuse(traced)

    self.assertTrue(all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()))

    N, C, H, W = 20, 3, 224, 224
    inp = torch.randn(N, C, H, W)

    self.assertEqual(fused(inp), rn18(inp))
def test_subgraph_creation(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x, y):
            z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
            w = self.linear(y).clamp(min=0.0, max=1.0)
            return z + w

    # symbolically trace the model
    my_module = MyModule()
    my_module_traced = symbolic_trace(my_module)

    # random mod partitioning
    partition_counter = 0
    NPARTITIONS = 3

    # Add some random meta info to make sure it is kept around.
    for node in my_module_traced.graph.nodes:
        if node.op != "output":
            node.meta["test_meta_info"] = True

    def mod_partition(node: Node):
        nonlocal partition_counter
        partition = partition_counter % NPARTITIONS
        partition_counter = (partition_counter + 1) % NPARTITIONS
        return partition

    # split the module into a module with submodules
    module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

    # Check that test_meta_info was still on all nodes.
    submodules = dict(module_with_submodules.named_modules())
    for node in module_with_submodules.graph.nodes:
        if node.op == "call_module":
            submod = submodules[node.target]
            self.assertTrue(isinstance(submod, torch.fx.GraphModule))
            for submod_node in submod.graph.nodes:
                if submod_node.op != "output":
                    stored_op = submod_node.meta.get("test_meta_info")
                    self.assertTrue(stored_op is not None and stored_op)

    x = torch.rand(3, 4)
    y = torch.rand(3, 4)

    orig_out = my_module_traced(x, y)
    submodules_out = module_with_submodules(x, y)

    self.assertEqual(orig_out, submodules_out)
def test_kl_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.b = torch.rand(4)
            self.c = torch.rand(4)
            self.d = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = add_1 + self.c
            linear_1 = self.linear(add_1)
            add_3 = add_2 + linear_1
            add_4 = add_2 + self.d
            add_5 = add_3 + add_4
            return add_4

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    transfer_rate_bytes_per_sec = 2
    devices = [
        Device('dev_0', 200, 0),
        Device('dev_1', 200, 1),
        Device('dev_2', 200, 2),
        Device('dev_3', 200, 3)
    ]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.kl_based,
        transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
        node_to_latency_mapping=node_to_latency_mapping
    )
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    dag = ret.dag
    assert dag.nodes[0].size_bytes == 176
    assert dag.nodes[1].size_bytes == 112
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitioner.partitions, node_to_latency_mapping
    )
    cost = get_latency_of_partitioned_graph(
        partitioner.partitions,
        partition_to_latency_mapping,
        transfer_rate_bytes_per_sec
    )
    assert cost == 208.
def test_partition_latency(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_4 = add_2 + linear_1
            add_5 = add_2 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Given an fx module, generate node latency for each node
        based on the size of each node.
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 2. * node.size_bytes.total_size)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    devices = [Device('dev_0', 200, 0), Device('dev_1', 200, 0)]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(devices, False)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_latency_0 = get_latency_of_one_partition(
        partitions[0], node_to_latency_mapping)
    assert (128., 80., 160.) == partition_latency_0
    partition_latency_1 = get_latency_of_one_partition(
        partitions[1], node_to_latency_mapping)
    assert (16., 32., 32.) == partition_latency_1
def test_annotate_returns_with_schema(self):
    m = resnet18()
    traced_modules = symbolic_trace(m)
    traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
    for node in traced_modules_annotated.graph.nodes:
        if node.type is None:
            check = (node.op, node.target)
            self.assertTrue(check in {
                ('placeholder', 'x'),
                ('call_function', operator.add),
                ('call_function', torch.flatten),
                ('output', 'output')
            })

    # Smoke test torchscript compilation since now we're emitting type annotations
    torch.jit.script(traced_modules_annotated)

    class FunctionalTracer(torch.fx.Tracer):
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            # `leaves` contains the set of standard `nn.Modules` that are not
            # currently symbolically traceable. Ideally this set would be empty
            leaves = set([torch.nn.BatchNorm2d])
            return type(m) in leaves

    traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

    traced_functionals_annotated = AnnotateTypesWithSchema(traced_functionals).transform()
    for node in traced_functionals_annotated.graph.nodes:
        if node.type is None:
            check = (node.op, node.target)
            excluded_nodes = {
                ('placeholder', 'x'),
                ('call_function', torch.conv2d),
                # Return type differs based on boolean dispatch :(
                ('call_function', torch.nn.functional.max_pool2d),
                ('call_function', operator.add),
                ('call_function', torch.flatten),
                ('output', 'output'),
            }
            self.assertTrue(check in excluded_nodes)

    # Smoke test torchscript compilation since now we're emitting type annotations
    torch.jit.script(traced_functionals_annotated)
def test_replace_target_nodes_with(self):
    class testModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = testModule()
    traced = symbolic_trace(m)
    input1 = torch.randn(1)
    input2 = torch.randn(1)
    assert (input1 + input2) == traced(input1, input2)
    graph_manipulation.replace_target_nodes_with(
        fx_module=traced,
        old_op="call_function",
        old_target=operator.add,
        new_op="call_function",
        new_target=operator.mul,
    )
    assert (input1 * input2) == traced(input1, input2)
def test_lack_of_devices(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
    catch_runtime_error = False
    try:
        ret = partitioner.partition_graph(traced, m, partitioner_config)
    except RuntimeError:
        catch_runtime_error = True
    assert catch_runtime_error
def test_find_single_partition(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(1)
    b = torch.rand(1)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125),
        Device('dev_1', 125),
        Device('dev_2', 125)
    ]
    ret = partitioner.partition_graph(traced, m, devices)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
def test_find_single_partition(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(1)
    b = torch.rand(1)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device("dev_0", 125, 0),
        Device("dev_1", 125, 1),
        Device("dev_2", 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert dag.nodes[0].logical_device_ids == [0]
def test_aot_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.b = torch.rand(4)
            self.c = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = self.c + add_1
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    node_to_partition_id = {}
    partition_to_logical_devices = {}
    count = 0
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    for node in traced.graph.nodes:
        if node.op not in {'placeholder', 'get_attr', 'output'}:
            node_to_partition_id[node] = count
            partition_to_logical_devices[count] = [0]
            count += 1
    devices = [Device('dev_0', 200, 0)]
    partitioner_config = PartitionerConfig(
        devices=devices,
        mode=PartitionMode.aot_based,
        node_to_partition_mapping=node_to_partition_id,
        partition_to_logical_device_mapping=partition_to_logical_devices
    )
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(module_with_submodules(a), traced(a))
    for node in dag.nodes:
        assert node.size_bytes == 48
        assert node.logical_device_ids == [0]
def test_subgraph_creation(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x, y):
            z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
            w = self.linear(y).clamp(min=0.0, max=1.0)
            return z + w

    # symbolically trace the model
    my_module = MyModule()
    my_module_traced = symbolic_trace(my_module)

    # random mod partitioning
    partition_counter = 0
    NPARTITIONS = 3

    def mod_partition(node: Node):
        nonlocal partition_counter
        partition = partition_counter % NPARTITIONS
        partition_counter = (partition_counter + 1) % NPARTITIONS
        return partition

    # split the module into a module with submodules
    module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

    x = torch.rand(3, 4)
    y = torch.rand(3, 4)

    orig_out = my_module_traced(x, y)
    submodules_out = module_with_submodules(x, y)

    self.assertEqual(orig_out, submodules_out)
def test_normalize_modules_exhaustive(self):
    """
    Exhaustively test `NormalizeArgs` on all standard
    torch.nn Module classes
    """
    for test_params in module_tests + new_module_tests:
        if 'constructor' not in test_params:
            constructor = getattr(torch.nn, test_params['module_name'])
        else:
            constructor = test_params['constructor']
        if 'constructor_args' not in test_params:
            args = ()
        else:
            args = test_params['constructor_args']
        mod = constructor(*args)
        # Skip modules that are not standard `torch.nn`
        # instances, including functionals. (functionals
        # are tested in test_normalize_args)
        if mod.__class__.__name__ not in dir(torch.nn):
            continue
        if 'input_fn' not in test_params:
            inputs = torch.randn(test_params['input_size'])
        else:
            inputs = test_params['input_fn']()
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs, )
        params = ', '.join(f'v{i}' for i in range(len(inputs)))

        # Generate a class to wrap this standard `nn.Module` instance
        test_classname = f'Test{mod.__class__.__name__}'
        test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
"""
        gbls = {'torch': torch}
        exec(test_mod_code, gbls)
        test_instance = gbls[test_classname](mod)
        traced = symbolic_trace(test_instance)

        # Now actually test arg normalization!
        traced = NormalizeArgs(traced).transform()

        # These Modules have an RNG in their forward, so testing
        # correctness by comparing outputs is not correct. Skip that
        # check for these
        stochastic_modules = {'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'}

        if mod.__class__.__name__ not in stochastic_modules:
            self.assertEqual(traced(*inputs), mod(*inputs))

        # Ensure all args/kwargs are normalized into kwargs
        modules = dict(traced.named_modules())
        for node in traced.graph.nodes:
            if node.op == 'call_module':
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same
    right-hand side operand into one large matrix multiplication.

                  ____      _________        _________
        ----     |    |    |         |     M|  A * C  |
      M| A  |  T |  B | * K|    C    |  =   |---------|
        ---- ,   |    |    |         |     T|  B * C  |
         K        ----      ---------        ---------
                   K            R                R
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
    # the matmul of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs, ), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs, ), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_sizes = [gm.graph.call_function(get_first_dim, (l, ), {}) for l in lhs]
        merge_mm_split = gm.graph.call_function(torch.split, (merge_mm, merge_mm_sizes), {})
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk
        # from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need to sort
    # the nodes topologically to make sure all definitions precede uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint(in_mod)
    return gm
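# A minimal usage sketch for merge_matmul above (illustrative only, not part of the
# original file). The names _TwoMatmuls and _check_merge_matmul are hypothetical, and
# the sketch assumes merge_matmul and its helpers (are_nodes_independent, get_first_dim,
# legalize_graph) are in scope and that get_first_dim returns the size of dim 0.
# Two matmuls share the right-hand operand `self.rhs`, so the transform should rewrite
# them into a single torch.matmul fed by a torch.cat of the left-hand operands.
class _TwoMatmuls(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rhs = torch.nn.Parameter(torch.randn(4, 5))

    def forward(self, a, b):
        # Both calls use the same RHS and are independent of each other,
        # so merge_matmul can fuse them.
        return torch.matmul(a, self.rhs), torch.matmul(b, self.rhs)


def _check_merge_matmul():
    mod = _TwoMatmuls()
    merged = merge_matmul(mod)
    a, b = torch.randn(3, 4), torch.randn(2, 4)
    # The rewritten graph should compute the same values as the eager module.
    ref = mod(a, b)
    out = merged(a, b)
    assert all(torch.allclose(r, o) for r, o in zip(ref, out))
    # Exactly one torch.matmul call should remain after merging.
    matmul_nodes = [n for n in merged.graph.nodes
                    if n.op == "call_function" and n.target is torch.matmul]
    assert len(matmul_nodes) == 1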
def test_serialize_graph(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.e = torch.rand(4)
            self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

        def forward(self, a, b, c):
            add_1 = a + b
            conv1 = self.conv(c)
            linear = self.linear(add_1 + conv1)
            add_2 = linear + self.e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    c = torch.rand(3, 3, 2, 2)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

    partitioner = Partitioner()
    devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules

    # Fix for now to add type/shape to output
    for node in traced.graph.nodes:
        if node.op == "output":
            node.meta['tensor_meta'] = extract_tensor_metadata(a)
    for mod in module_with_submodules.modules():
        if isinstance(mod, GraphModule):
            for node in mod.graph.nodes:
                node.meta['tensor_meta'] = extract_tensor_metadata(a)
    for node in module_with_submodules.graph.nodes:
        node.meta['tensor_meta'] = extract_tensor_metadata(a)

    weights1 = {}
    weights2 = {}
    serialized_graph1 = graph_manipulation.serialize_module(traced, weights1)
    serialized_graph2 = graph_manipulation.serialize_module(module_with_submodules, weights2)
    assert len(weights1) == 4
    assert len(weights2) == 4
    assert len(serialized_graph1["nodes"]) == 10
    assert len(serialized_graph1["weights"]) == 4
    assert len(serialized_graph1["modules"]) == 0
    assert len(serialized_graph2["nodes"]) == 6
    assert len(serialized_graph2["weights"]) == 4
    assert len(serialized_graph2["modules"]) == 1
    assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
    assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"
    assert serialized_graph1["weights"]["linear.weight"]["is_quantized"] is False
    assert serialized_graph1["nodes"][0]["shape"] == "[4]"
    assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
    assert serialized_graph1["nodes"][0]["target"] == "a"
    assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
    assert serialized_graph1["nodes"][0]["name"] == "a"
    assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1"
    assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

    # Test quantization info serialization.
    x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
    q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
    q_tensor_channel = torch.quantize_per_channel(
        x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
    result = graph_manipulation.serialize_tensor_quantization(q_tensor)
    result2 = graph_manipulation.serialize_tensor_quantization(q_tensor_channel)
    assert result["qscheme"] == "torch.per_tensor_affine"
    assert result["q_scale"] == 1.0
    assert result2["qscheme"] == "torch.per_channel_affine"
    assert len(result2["q_per_channel_scales"]) == 2
def test_cost_aware_partition(self):
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 1)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3)
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.cost_aware,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec)
    assert critical_path_latency_sec == 160.