def test_cost_aware_partition(self):
    """Cost-aware partitioning keeps the module's output unchanged and yields a
    critical-path latency of 160 for this module on four 125-byte devices."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Return a synthetic ``NodeLatency`` for every op node of ``fx_module``.

        Placeholder, get_attr and output nodes are skipped. The first
        ``NodeLatency`` argument is the node's ``size_bytes.total_size``; the
        second is its ``size_bytes.output_size``, or 1 when the two are equal.
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op in {'output', 'placeholder', 'get_attr'}:
                continue
            total_size = node.size_bytes.total_size
            output_size = node.size_bytes.output_size
            second = 1 if total_size == output_size else output_size
            node_to_latency_mapping[node] = NodeLatency(total_size, second)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    # Tag every node with its size so the latency mapping can read size_bytes.
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3),
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.cost_aware,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping,
    )
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    # Partitioning must not change what the module computes.
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions,
        partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec,
    )
    assert critical_path_latency_sec == 160.
def test_kl_based_partition(self):
    """KL-based partitioning keeps the module's output unchanged, produces two
    partitions of 176 and 112 bytes, and a partitioned-graph latency of 208."""

    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.b = torch.rand(4)
            self.c = torch.rand(4)
            self.d = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = add_1 + self.c
            linear_1 = self.linear(add_1)
            add_3 = add_2 + linear_1
            add_4 = add_2 + self.d
            # NOTE(review): add_5 is not returned, but symbolic tracing still
            # records it; the expected sizes/latency below presumably include
            # it — confirm before removing.
            add_5 = add_3 + add_4
            return add_4

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    # NOTE(review): get_node_to_latency_mapping is not defined in this method;
    # assumes a helper of that name is available at module scope.
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    transfer_rate_bytes_per_sec = 2
    devices = [
        Device('dev_0', 200, 0),
        Device('dev_1', 200, 1),
        Device('dev_2', 200, 2),
        Device('dev_3', 200, 3),
    ]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.kl_based,
        transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
        node_to_latency_mapping=node_to_latency_mapping,
    )
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    # Partitioning must not change what the module computes.
    self.assertEqual(traced(a), module_with_submodules(a))
    dag = ret.dag
    # Compare each partition's recorded size, not the DAG node object itself,
    # with the expected byte counts.
    assert dag.nodes[0].size_bytes == 176
    assert dag.nodes[1].size_bytes == 112
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitioner.partitions, node_to_latency_mapping)
    cost = get_latency_of_partitioned_graph(
        partitioner.partitions,
        partition_to_latency_mapping,
        transfer_rate_bytes_per_sec,
    )
    assert cost == 208.
def test_aot_based_partition(self):
    """AOT partitioning with one partition per op node, all mapped to logical
    device 0, keeps the output unchanged and gives each partition 48 bytes."""

    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.b = torch.rand(4)
            self.c = torch.rand(4)

        def forward(self, a):
            add_1 = a + self.b
            add_2 = self.c + add_1
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    node_to_partition_id = {}
    partition_to_logical_devices = {}
    count = 0
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    # Give every op node its own partition id; every partition goes to
    # logical device 0.
    for node in traced.graph.nodes:
        if node.op not in {'placeholder', 'get_attr', 'output'}:
            node_to_partition_id[node] = count
            partition_to_logical_devices[count] = [0]
            count += 1
    devices = [Device('dev_0', 200, 0)]
    partitioner_config = PartitionerConfig(
        devices=devices,
        mode=PartitionMode.aot_based,
        node_to_partition_mapping=node_to_partition_id,
        partition_to_logical_device_mapping=partition_to_logical_devices,
    )
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    # Partitioning must not change what the module computes.
    self.assertEqual(module_with_submodules(a), traced(a))
    for node in dag.nodes:
        assert node.size_bytes == 48
        assert node.logical_device_ids == [0]

def test_replace_target_nodes_with(self):
    """Rewriting every call_function ``operator.add`` node to ``operator.mul``
    turns a traced ``a + b`` module into one that computes ``a * b``."""

    class testModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = testModule()
    traced = symbolic_trace(m)
    input1 = torch.randn(1)
    input2 = torch.randn(1)
    assert (input1 + input2) == traced(input1, input2)
    graph_manipulation.replace_target_nodes_with(
        fx_module=traced,
        old_op="call_function",
        old_target=operator.add,
        new_op="call_function",
        new_target=operator.mul,
    )
    assert (input1 * input2) == traced(input1, input2)
def test_sparse_nn_partition(self):
    """Sparse-NN partitioning of a small recommendation model across three
    equal-memory devices keeps the output unchanged, and the partitioned
    top-level graph has 24 nodes."""

    class MyRecommendationModule(torch.nn.Module):
        def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
            # ``num_of_layers`` (Linear, ReLU) pairs, in that order.
            mlp = torch.nn.ModuleList()
            for _ in range(num_of_layers):
                mlp.append(torch.nn.Linear(input_size, output_size))
                mlp.append(torch.nn.ReLU())
            return mlp

        def __init__(self):
            super().__init__()
            self.bottom_layers = torch.nn.Sequential(*self.create_mlp(4, 4, 4))
            self.top_layers = torch.nn.Sequential(*self.create_mlp(3, 24, 24))
            # Five sum-pooled sparse tables of width 4; the outer two are
            # smaller than the three in the middle.
            table_rows = (500000, 1000000, 1000000, 1000000, 500000)
            self.embedding_layers = torch.nn.ModuleList(
                torch.nn.EmbeddingBag(rows, 4, mode="sum", sparse=True)
                for rows in table_rows
            )

        def forward(self, a, b, offset):
            x = self.bottom_layers(a)
            shifted = [
                torch.randint(10, (8, )) + b
                for _ in range(len(self.embedding_layers))
            ]
            pooled = []
            for idx, table in enumerate(self.embedding_layers):
                # Even tables look up the shifted indices; odd tables get a
                # fresh random index tensor.
                if idx % 2 == 0:
                    pooled.append(table(shifted[idx], offset))
                else:
                    pooled.append(table(torch.randint(10, (8, )), offset))
            z = torch.cat([x] + pooled, dim=1)
            return self.top_layers(z)

    m = MyRecommendationModule()
    a = torch.rand(2, 4)
    b = torch.randint(10, (8, ))
    offset = torch.randint(1, (2, ))
    traced = symbolic_trace(m)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
    devices = [Device(f"dev_{idx}", 33000000, idx) for idx in range(3)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
    partitioner = Partitioner()
    result = partitioner.partition_graph(traced, m, partitioner_config)
    partitioned = result.module_with_submodules
    self.assertEqual(traced(a, b, offset), partitioned(a, b, offset))
    assert len(partitioned.graph.nodes) == 24
def partition_graph(
    self,
    fx_module: GraphModule,
    torch_module: torch.nn.Module,
    partitioner_config: PartitionerConfig,
) -> PartitionResult:
    """Partition ``fx_module`` onto the devices in ``partitioner_config``.

    The inputs are stored on ``self`` (``graph_module``, ``torch_module``,
    ``devices``) and every node is tagged with its size in bytes. The
    strategy is then chosen in this order:

    * ``PartitionMode.aot_based``: apply the caller-supplied node-to-partition
      and partition-to-logical-device mappings;
    * if the graph fits on the device with the most available memory, use a
      single partition;
    * if the graph is larger than all devices' memory combined, fail;
    * otherwise dispatch on the mode: ``sparse_nn``, ``cost_aware``,
      ``kl_based``, or size-based partitioning for any other mode.

    The graph size is the sum of ``size_bytes.total_size`` over the nodes
    that come before the first ``output`` node.

    Args:
        fx_module: traced module to partition.
        torch_module: the original torch module.
        partitioner_config: devices, partition mode and mode-specific inputs.

    Returns:
        A ``PartitionResult`` built from the partition DAG and the module
        whose submodules are the partitions.

    Raises:
        RuntimeError: when no devices are given, the module has only
            placeholder/get_attr/output nodes, the devices' combined memory
            is smaller than the module, or (``sparse_nn`` mode) the devices
            do not all have the same available memory.
    """
    self.graph_module = fx_module
    self.torch_module = torch_module
    self.devices = partitioner_config.devices
    if len(self.devices) == 0:
        raise RuntimeError('No devices')

    # Annotate every node of the graph module with its size in bytes.
    get_size_of_all_nodes(self.graph_module)

    graph_nodes = self.graph_module.graph.nodes
    non_op_kinds = {'placeholder', 'get_attr', 'output'}
    if not any(node.op not in non_op_kinds for node in graph_nodes):
        raise RuntimeError(
            'No Partition since no operations in the module')

    graph_bytes = 0
    for node in graph_nodes:
        if node.op == 'output':
            break
        graph_bytes += node.size_bytes.total_size

    roomiest_device = max(
        self.devices, key=lambda dev: dev.available_mem_bytes)
    mode = partitioner_config.mode

    if mode == PartitionMode.aot_based:
        self.aot_based_partition(
            partitioner_config.node_to_partition_mapping,
            partitioner_config.partition_to_logical_device_mapping,
        )
    elif graph_bytes <= roomiest_device.available_mem_bytes:
        # Everything fits on one device.
        self.find_single_partition(graph_bytes)
    elif graph_bytes > sum(dev.available_mem_bytes for dev in self.devices):
        raise RuntimeError('Devices have no enough memory for the module')
    elif mode == PartitionMode.sparse_nn:
        shared_mem_bytes = self.devices[0].available_mem_bytes
        # sparse_nn_partition only supports devices with one common memory size.
        # TODO: add different size support for sparse_nn_partition
        if not all(dev.available_mem_bytes == shared_mem_bytes
                   for dev in self.devices):
            raise RuntimeError('All devices must have same memory size!')
        self.sparse_nn_partition(shared_mem_bytes)
    elif mode == PartitionMode.cost_aware:
        self.cost_aware_partition(
            partitioner_config.transfer_rate_bytes_per_sec,
            partitioner_config.node_to_latency_mapping,
        )
    elif mode == PartitionMode.kl_based:
        self.kl_based_partition(
            partitioner_config.transfer_rate_bytes_per_sec,
            partitioner_config.node_to_latency_mapping,
        )
    else:
        self.size_based_partition()

    module_with_submodules = self.do_partition()
    # The DAG records each partition's input/output nodes and how the
    # partitions are connected.
    dag = self.dump_dag(module_with_submodules)
    return PartitionResult(dag, module_with_submodules)
def test_serialize_graph(self):
    """Serialize a traced module and its partitioned counterpart and check
    node/weight/module counts and per-entry metadata; then check that
    per-tensor and per-channel quantization parameters serialize."""

    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.e = torch.rand(4)
            self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

        def forward(self, a, b, c):
            add_1 = a + b
            conv1 = self.conv(c)
            linear = self.linear(add_1 + conv1)
            add_2 = linear + self.e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    c = torch.rand(3, 3, 2, 2)
    graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

    partitioner = Partitioner()
    devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
    partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
    result = partitioner.partition_graph(traced, m, partitioner_config)
    partitioned = result.module_with_submodules

    def stamp(node):
        # Give the node the shape/dtype of ``a`` so it can be serialized.
        node.shape = a.shape
        node.dtype = a.dtype

    # Fix for now to add type/shape to output
    for node in traced.graph.nodes:
        if node.op == "output":
            stamp(node)
    for submodule in partitioned.modules():
        if isinstance(submodule, GraphModule):
            for node in submodule.graph.nodes:
                stamp(node)
    for node in partitioned.graph.nodes:
        stamp(node)

    weights1, weights2 = {}, {}
    serialized_graph1 = graph_manipulation.serialize_module(traced, weights1)
    serialized_graph2 = graph_manipulation.serialize_module(
        partitioned, weights2)
    assert len(weights1) == 4
    assert len(weights2) == 4

    expected_counts = (
        (serialized_graph1, 10, 4, 0),
        (serialized_graph2, 6, 4, 1),
    )
    for graph, n_nodes, n_weights, n_modules in expected_counts:
        assert len(graph["nodes"]) == n_nodes
        assert len(graph["weights"]) == n_weights
        assert len(graph["modules"]) == n_modules

    linear_weight = serialized_graph1["weights"]["linear.weight"]
    assert linear_weight["shape"] == "[4, 4]"
    assert linear_weight["dtype"] == "torch.float32"
    assert linear_weight["is_quantized"] is False

    first_node = serialized_graph1["nodes"][0]
    assert first_node["shape"] == "[4]"
    assert first_node["dtype"] == "torch.float32"
    assert first_node["target"] == "a"
    assert first_node["op_code"] == "placeholder"
    assert first_node["name"] == "a"

    output_arg = serialized_graph1["nodes"][6]["args"][0]
    assert output_arg["name"] == "add_2"
    assert output_arg["is_node"] is True

    # Test quantization info serialization.
    x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
    per_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
    per_channel = torch.quantize_per_channel(
        x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
    per_tensor_info = graph_manipulation.serialize_tensor_quantization(
        per_tensor)
    per_channel_info = graph_manipulation.serialize_tensor_quantization(
        per_channel)
    assert per_tensor_info["q_scheme"] == "torch.per_tensor_affine"
    assert per_tensor_info["q_scale"] == 1.0
    assert per_channel_info["q_scheme"] == "torch.per_channel_affine"
    assert len(per_channel_info["q_per_channel_scales"]) == 2
def test_cost_aware_partition(self):
    """Cost-aware partitioning keeps the module's output unchanged and yields a
    critical-path latency of 160 for this module on four 125-byte devices."""

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_3 = add_2 + torch.rand(4)
            add_4 = add_2 + linear_1
            add_5 = add_3 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Return a synthetic ``NodeLatency`` for every op node of ``fx_module``.

        Placeholder, get_attr and output nodes are skipped. The first
        ``NodeLatency`` argument is the node's ``size_bytes.total_size``; the
        second is its ``size_bytes.output_size``, or 1 when the two are equal.
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op in {'output', 'placeholder', 'get_attr'}:
                continue
            total_size = node.size_bytes.total_size
            output_size = node.size_bytes.output_size
            second = 1 if total_size == output_size else output_size
            node_to_latency_mapping[node] = NodeLatency(total_size, second)
        return node_to_latency_mapping

    m = MyModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    # Tag every node with its size so the latency mapping can read size_bytes.
    graph_manipulation.get_size_of_all_nodes(traced, [a])
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2),
        Device('dev_3', 125, 3),
    ]
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    # partition_graph selects the strategy from ``partitioner_config.mode``,
    # so cost-aware partitioning must be requested through ``mode``.
    partitioner_config = PartitionerConfig(
        devices,
        mode=PartitionMode.cost_aware,
        transfer_rate_bytes_per_sec=2,
        node_to_latency_mapping=node_to_latency_mapping,
    )
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    # Partitioning must not change what the module computes.
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions,
        partition_to_latency_mapping,
        partitioner_config.transfer_rate_bytes_per_sec,
    )
    assert critical_path_latency_sec == 160.
def test_replace_target_nodes_with(self):
    """Rewriting every call_function ``operator.add`` node to ``operator.mul``
    turns a traced ``a + b`` module into one that computes ``a * b``."""

    class AddModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    module = AddModule()
    traced = symbolic_trace(module)
    lhs = torch.randn(1)
    rhs = torch.randn(1)
    # Before the rewrite the traced module adds its inputs.
    assert (lhs + rhs) == traced(lhs, rhs)

    graph_manipulation.replace_target_nodes_with(
        fx_module=traced,
        old_op="call_function",
        old_target=operator.add,
        new_op="call_function",
        new_target=operator.mul,
    )
    # After the rewrite the same module multiplies them.
    assert (lhs * rhs) == traced(lhs, rhs)