def test_partition_combining(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            b = torch.rand(4)
            add_1 = a + b
            linear_1 = self.linear(add_1)
            add_2 = torch.rand(4) + a
            add_3 = add_2 + linear_1
            return add_3

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    partitioner = Partitioner()
    devices = [Device('dev_0', 120, 0), Device('dev_1', 144, 1)]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=False)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    assert dag.nodes[0].logical_device_ids == [0]
    assert dag.nodes[0].size_bytes == 80
    assert dag.nodes[1].logical_device_ids == [1]
    assert dag.nodes[1].size_bytes == 144
def test_partition_combining(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_0 = torch.nn.Linear(4, 4)

        def forward(self, a, b):
            add_1 = a + b
            c = self.linear_0(a)
            add_2 = c + add_1
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert len(module_with_submodules.graph.nodes) == 5
def test_size_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a, b):
            add_1 = a + b
            linear = self.linear(add_1)
            e = torch.rand(4)
            add_2 = linear + e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    for i, node in enumerate(dag.nodes):
        assert node.logical_device_ids == [i]
def test_partition_device_mapping(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            b = torch.rand(4)
            add_1 = a + b
            linear_1 = self.linear(add_1)
            add_2 = torch.rand(4) + a
            add_3 = add_2 + linear_1
            return add_3

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    partitioner = Partitioner()
    devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=False)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a), module_with_submodules(a))
    for i, node in enumerate(dag.nodes):
        if i == 1:
            assert node.logical_device_ids == [1]
        else:
            assert node.logical_device_ids == [0]
def test_size_based_partition(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a, b):
            add_1 = a + b
            linear = self.linear(add_1)
            e = torch.rand(4)
            add_2 = linear + e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert len(module_with_submodules.graph.nodes) == 7
def test_sparse_nn_partition(self):
    class MyRecommendationModule(torch.nn.Module):
        # Helper that builds a stack of Linear + ReLU layers.
        def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
            layers = torch.nn.ModuleList()
            for _ in range(num_of_layers):
                ll = torch.nn.Linear(input_size, output_size)
                layers.append(ll)
                layers.append(torch.nn.ReLU())
            return layers

        def __init__(self):
            super(MyRecommendationModule, self).__init__()
            # Bottom MLP over dense features, top MLP over the concatenated
            # features, and a list of sparse embedding tables in between.
            layers = self.create_mlp(4, 4, 4)
            self.bottom_layers = torch.nn.Sequential(*layers)
            layers = self.create_mlp(3, 24, 24)
            self.top_layers = torch.nn.Sequential(*layers)
            self.embedding_layers = torch.nn.ModuleList()
            el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
            self.embedding_layers.append(el)
            for i in range(3):
                el = torch.nn.EmbeddingBag(1000000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)
            el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
            self.embedding_layers.append(el)

        def forward(self, a, b, offset):
            x = self.bottom_layers(a)
            y = []
            c = []
            for i in range(len(self.embedding_layers)):
                temp = torch.randint(10, (8, ))
                c.append(temp + b)
            for i in range(len(self.embedding_layers)):
                if i % 2 == 0:
                    y.append(self.embedding_layers[i](c[i], offset))
                else:
                    y.append(self.embedding_layers[i](torch.randint(10, (8, )), offset))
            # Concatenate the dense output with all embedding lookups before the top MLP.
            z = torch.cat([x] + y, dim=1)
            p = self.top_layers(z)
            return p

    m = MyRecommendationModule()
    a = torch.rand(2, 4)
    b = torch.randint(10, (8, ))
    offset = torch.randint(1, (2, ))
    traced = symbolic_trace(m)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b, offset])
    devices = [
        Device('dev_0', 33000000, 0),
        Device('dev_1', 33000000, 1),
        Device('dev_2', 33000000, 2)
    ]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
    partitioner = Partitioner()
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
    assert len(module_with_submodules.graph.nodes) == 24
def test_partition_latency(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super(TestModule, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, a):
            add_1 = a + torch.rand(4)
            add_2 = add_1 + torch.rand(4)
            linear_1 = self.linear(add_1)
            add_4 = add_2 + linear_1
            add_5 = add_2 + add_4
            return add_5

    def get_node_to_latency_mapping(fx_module: GraphModule):
        """Given an fx module, generate a node latency for each node
        based on the node's size.
        """
        node_to_latency_mapping: Dict[Node, NodeLatency] = {}
        for node in fx_module.graph.nodes:
            if node.op not in {'output', 'placeholder', 'get_attr'}:
                if node.size_bytes.total_size == node.size_bytes.output_size:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, 2. * node.size_bytes.total_size)
                else:
                    node_to_latency_mapping[node] = NodeLatency(
                        node.size_bytes.total_size, node.size_bytes.output_size)
        return node_to_latency_mapping

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a])
    node_to_latency_mapping = get_node_to_latency_mapping(traced)
    devices = [
        Device('dev_0', 200, 0),
        Device('dev_1', 200, 1)
    ]
    partitioner = Partitioner()
    partitioner_config = PartitionerConfig(devices, False)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a), module_with_submodules(a))
    partitions = partitioner.partitions
    partition_to_latency_mapping = get_partition_to_latency_mapping(
        partitions, node_to_latency_mapping)
    for p in partition_to_latency_mapping:
        if p.partition_id == 0:
            assert partition_to_latency_mapping[p] == (128., 80., 160.)
        else:
            assert partition_to_latency_mapping[p] == (16., 32., 32.)
    transfer_rate_bytes_per_sec = 0.5
    critical_path_latency_sec = get_latency_of_partitioned_graph(
        partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
    )
    assert critical_path_latency_sec == 208.
def test_find_single_partition(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(1)
    b = torch.rand(1)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
def test_find_single_partition(self):
    class TestModule(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(1)
    b = torch.rand(1)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [
        Device('dev_0', 125, 0),
        Device('dev_1', 125, 1),
        Device('dev_2', 125, 2)
    ]
    partitioner_config = PartitionerConfig(devices)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules
    dag = ret.dag
    self.assertEqual(traced(a, b), module_with_submodules(a, b))
    assert dag.nodes[0].logical_device_ids == [0]
def test_serialize_graph(self):
    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)
            self.e = torch.rand(4)

        def forward(self, a, b):
            add_1 = a + b
            linear = self.linear(add_1)
            add_2 = linear + self.e
            return add_2

    m = TestModule()
    traced = symbolic_trace(m)
    a = torch.rand(4)
    b = torch.rand(4)
    GraphManipulation.get_size_of_all_nodes(traced, [a, b])
    partitioner = Partitioner()
    devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
    partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
    ret = partitioner.partition_graph(traced, m, partitioner_config)
    module_with_submodules = ret.module_with_submodules

    # Fix for now to add type/shape to output
    for node in traced.graph.nodes:
        if node.op == "output":
            node.shape = a.shape
            node.dtype = a.dtype
    for mod in module_with_submodules.modules():
        if isinstance(mod, GraphModule):
            for node in mod.graph.nodes:
                node.shape = a.shape
                node.dtype = a.dtype
    for node in module_with_submodules.graph.nodes:
        node.shape = a.shape
        node.dtype = a.dtype

    agm1 = GraphManipulation.AcceleratedGraphModule(traced)
    agm2 = GraphManipulation.AcceleratedGraphModule(module_with_submodules)
    assert len(agm1.weights) == 3
    assert len(agm2.weights) == 3
    assert len(agm1.serialized_graph["nodes"]) == 7
    assert len(agm1.serialized_graph["weights"]) == 3
    assert len(agm1.serialized_graph["modules"]) == 0
    assert len(agm2.serialized_graph["nodes"]) == 5
    assert len(agm2.serialized_graph["weights"]) == 3
    assert len(agm2.serialized_graph["modules"]) == 1
    assert agm1.serialized_graph["weights"]["linear.weight"]["shape"] == "[4, 4]"
    assert agm1.serialized_graph["weights"]["linear.weight"]["dtype"] == "torch.float32"
    assert agm1.serialized_graph["weights"]["linear.weight"]["is_quantized"] is False
    assert agm1.serialized_graph["nodes"][0]["shape"] == "[4]"
    assert agm1.serialized_graph["nodes"][0]["dtype"] == "torch.float32"
    assert agm1.serialized_graph["nodes"][0]["target"] == "a"
    assert agm1.serialized_graph["nodes"][0]["op_code"] == "placeholder"
    assert agm1.serialized_graph["nodes"][0]["name"] == "a"
    assert agm1.serialized_graph["nodes"][2]["args"][0]["name"] == "a"
    assert agm1.serialized_graph["nodes"][2]["args"][0]["is_node"] is True

    # Test quantization info serialization.
    x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
    q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
    q_tensor_channel = torch.quantize_per_channel(
        x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8)
    result = GraphManipulation.serialize_tensor_quantization(q_tensor)
    result2 = GraphManipulation.serialize_tensor_quantization(q_tensor_channel)
    assert result["q_scheme"] == "torch.per_tensor_affine"
    assert result["q_scale"] == 1.0
    assert result2["q_scheme"] == "torch.per_channel_affine"
    assert len(result2["q_per_channel_scales"]) == 2
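# The test methods above assume the surrounding scaffolding of PyTorch's experimental
# FX partitioner test file. A minimal sketch of that scaffolding follows; the exact
# import paths and the class name are assumptions (the experimental modules have moved
# between PyTorch versions), so treat this as illustrative rather than authoritative:
#
#     from typing import Dict
#
#     import torch
#     from torch.fx import GraphModule, Node, symbolic_trace
#     from torch.fx.experimental import GraphManipulation
#     from torch.fx.experimental.Partitioner import Device, Partitioner, PartitionerConfig
#     from torch.fx.experimental.partitioner_utils import (
#         NodeLatency,
#         get_latency_of_partitioned_graph,
#         get_partition_to_latency_mapping,
#     )
#     from torch.testing._internal.common_utils import TestCase, run_tests
#
#     class TestPartitioner(TestCase):  # hypothetical name; the methods above live here,
#         ...                           # indented one level as class methods
#
#     if __name__ == "__main__":
#         run_tests()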