def test_fetch(self):
        """Lift lowering attributes onto a traced Conv2d/BatchNorm2d module and verify them.

        Every ``call_module`` node must expose ``attrs_for_lowering`` holding
        the attributes listed below for its module class, plus a ``"name"`` entry.
        """
        # Attribute names expected per module, keyed by the value stored under "name".
        expected_by_class: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        graph_module = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(graph_module)

        for node in graph_module.graph.nodes:
            if node.op != "call_module":
                continue

            assert hasattr(node, "attrs_for_lowering")
            wanted = expected_by_class[node.attrs_for_lowering["name"]]

            # The lifted dict holds one entry beyond the listed attributes: the class name.
            assert len(wanted) + 1 == len(node.attrs_for_lowering)
            for attr_name in wanted:
                assert attr_name in node.attrs_for_lowering
    def test_subgraph_uniquename(self):
        """Split a traced module into two partitions and compare its output with the unsplit trace."""

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                return add_3 + add_4

        inputs = [torch.ones(4) for _ in range(4)]
        mm = MyModule()
        traced = symbolic_trace(mm)

        # Graph nodes with these names are routed to partition 0.
        first_partition_names = ('a', 'b', 'add')

        def split_cb(node: torch.fx.Node):
            return 0 if node.name in first_partition_names else 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(*inputs), traced(*inputs))
    def test_large_node_error(self):
        """Expect size-based partitioning to raise ``RuntimeError`` for this setup.

        A Linear(4, 4) + add module is traced, node sizes are computed for one
        input of 4 elements, and five devices are offered, each constructed
        with 40 as its second argument (presumably the available memory in
        bytes -- confirm against ``Device``).
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.size_based)
        # assertRaises keeps the check active under ``python -O`` (bare asserts
        # are stripped there) and reports a clear message when nothing is raised.
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)
    def test_size_based_partition(self):
        """Partition a small module by size across three devices.

        Verifies that the partitioned module reproduces the traced module's
        output and that the i-th DAG node is assigned logical device id ``i``.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                summed = a + b
                projected = self.linear(summed)
                return projected + self.c

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        config = PartitionerConfig(devices, PartitionMode.size_based)
        result = partitioner.partition_graph(traced, m, config)
        partitioned = result.module_with_submodules
        dag = result.dag
        self.assertEqual(traced(a, b), partitioned(a, b))
        for expected_id, dag_node in enumerate(dag.nodes):
            assert dag_node.logical_device_ids == [expected_id]
Example #5
0
    def test_partition_combining(self):
        """Partition with a default ``PartitionerConfig`` and check the top-level graph has 5 nodes."""

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_0 = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                summed = a + b
                projected = self.linear_0(a)
                return projected + summed

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
        ]
        config = PartitionerConfig(devices)
        outcome = partitioner.partition_graph(traced, m, config)
        partitioned = outcome.module_with_submodules
        dag = outcome.dag
        # The partitioned module must be numerically equivalent to the trace.
        self.assertEqual(traced(a, b), partitioned(a, b))
        assert len(partitioned.graph.nodes) == 5
    def test_partition_device_mapping(self):
        """Size-based partitioning over two devices; check each DAG node's logical device.

        The DAG node at index 1 is expected on logical device 1 and every
        other DAG node on logical device 0.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        config = PartitionerConfig(devices, PartitionMode.size_based)
        result = partitioner.partition_graph(traced, m, config)
        partitioned = result.module_with_submodules
        dag = result.dag
        self.assertEqual(traced(a), partitioned(a))
        for index, dag_node in enumerate(dag.nodes):
            expected_ids = [1] if index == 1 else [0]
            assert dag_node.logical_device_ids == expected_ids
    def test_to_folder(self):
        """Dump a traced module with ``to_folder``, import it back and compare outputs.

        The traced module is written to a temporary directory under the module
        name ``'Foo'``, loaded from the generated ``__init__.py`` and executed
        on a random 2x2 input; the result must equal the traced module's output.
        """

        class Test(torch.nn.Module):
            def __init__(self):
                super(Test, self).__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.register_buffer('attr2', torch.randn(2))

            def forward(self, x):
                return self.linear(
                    self.seq(self.W + self.attr + self.attr2 + x))

        mod = symbolic_trace(Test())
        module_name = 'Foo'
        import importlib.util
        import tempfile
        from pathlib import Path
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / '__init__.py')
            module = importlib.util.module_from_spec(spec)
            # Remember whatever was registered under this name so the test
            # does not leak (or clobber) a global sys.modules entry.
            missing = object()
            previous = sys.modules.get(module_name, missing)
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
                t = torch.randn(2, 2)
                self.assertEqual(module.Foo()(t), mod(t))
            finally:
                if previous is missing:
                    sys.modules.pop(module_name, None)
                else:
                    sys.modules[module_name] = previous
    def test_partition_node_manipulation(self):
        """Remove a node from a partition and check ``used_mem_bytes`` is updated.

        After partitioning onto a single device, the first partition is
        expected to report 112 used bytes; once the graph node named
        ``'add_2'`` is removed it is expected to report 80.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device('dev_0', 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select add_2 node to remove
        selected_node = next(
            (node for node in partition.nodes if node.name == 'add_2'), None)
        # Fail with a clear message instead of handing None to remove_node.
        self.assertIsNotNone(
            selected_node, "no node named 'add_2' found in the partition")
        partition.remove_node(selected_node)
        assert (partition.used_mem_bytes == 80)
    def test_partition_combining(self):
        """Partition onto two devices and check each DAG node's device id and size.

        DAG node 0 is expected on logical device 0 with 80 bytes and DAG
        node 1 on logical device 1 with 144 bytes.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device('dev_0', 120, 0), Device('dev_1', 144, 1)]
        config = PartitionerConfig(devices, is_sparse_nn=False)
        result = partitioner.partition_graph(traced, m, config)
        partitioned = result.module_with_submodules
        dag = result.dag
        self.assertEqual(traced(a), partitioned(a))
        # (logical device ids, size in bytes) expected for DAG nodes 0 and 1.
        expectations = [([0], 80), ([1], 144)]
        for index, (device_ids, size) in enumerate(expectations):
            assert dag.nodes[index].logical_device_ids == device_ids
            assert dag.nodes[index].size_bytes == size
Example #10
0
    def test_size_based_partition(self):
        """Partition across three devices passed directly to ``partition_graph``.

        The partitioned module must match the traced module's output and its
        top-level graph is expected to contain 7 nodes.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                e = torch.rand(4)
                return linear + e

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device('dev_0', 125),
            Device('dev_1', 125),
            Device('dev_2', 125),
        ]
        result = partitioner.partition_graph(traced, m, devices)
        partitioned = result.module_with_submodules
        self.assertEqual(traced(a, b), partitioned(a, b))
        assert len(partitioned.graph.nodes) == 7
Example #11
0
    def test_sparse_nn_partition(self):
        """Partition a recommendation-style model with ``is_sparse_nn=True``.

        The model has a bottom MLP, five ``EmbeddingBag`` tables and a top
        MLP. It is traced, sized, and partitioned over three devices; the
        partitioned module must reproduce the traced module's output and its
        top-level graph is expected to contain 24 nodes.
        """
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                """Return a ModuleList of ``num_of_layers`` (Linear, ReLU) pairs."""
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super(MyRecommendationModule, self).__init__()
                # Bottom MLP: 4 Linear(4, 4)+ReLU pairs.
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                # Top MLP: 3 Linear(24, 24)+ReLU pairs.
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                # Five embedding tables of dimension 4: one with 500000 rows,
                # three with 1000000 rows, and a final one with 500000 rows.
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode='sum', sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode='sum', sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                # One index tensor per embedding table: random indices plus ``b``.
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8, ))
                    c.append(temp + b)
                # Even-numbered tables look up c[i]; odd-numbered tables look up
                # fresh random indices that do not depend on the inputs.
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(self.embedding_layers[i](torch.randint(10, (8, )), offset))
                # Concatenate the bottom-MLP output with all embedding outputs
                # along dim 1 before feeding the top MLP.
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8, ))
        offset = torch.randint(1, (2, ))
        traced = symbolic_trace(m)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b, offset])
        # Three devices, each constructed with 33000000 as its second argument
        # (presumably available memory in bytes -- confirm against ``Device``).
        devices = [
            Device('dev_0', 33000000, 0),
            Device('dev_1', 33000000, 1),
            Device('dev_2', 33000000, 2)
        ]
        partitioner_config = PartitionerConfig(devices, is_sparse_nn=True)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24
 def test_subgraph_trivial_resnet(self):
      """Smoke test: splitting resnet18 into a single partition still runs."""
      # Guards against an earlier issue where submodule names were aliased.
      model = resnet18()
      graph_module = symbolic_trace(model)
      sample = torch.rand(64, 3, 7, 7)

      def single_partition(node):
          # Every node lands in partition 0.
          return 0

      split = split_module(graph_module, model, single_partition)
      split(sample)
Example #13
0
    def test_cost_aware_partition(self):
        """Run cost-aware partitioning over four devices and check the critical-path latency.

        The partitioned module must reproduce the traced module's output, and
        the latency computed from the resulting partitions with a transfer
        rate of 2 bytes/sec is expected to equal 160.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Build a ``NodeLatency`` for every node except output/placeholder/get_attr.

            The first ``NodeLatency`` argument is the node's total size. The
            second is 1 when the total and output sizes are equal, otherwise
            the node's output size.
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {'output', 'placeholder', 'get_attr'}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1)
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size,
                            node.size_bytes.output_size)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            is_sparse_nn=False,
            is_cost_aware=True,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.
Example #14
0
    def test_partition_latency(self):
        """Check per-partition latencies and the critical-path latency of a two-device split.

        Partition 0 is expected to have latency ``(128.0, 80.0, 160.0)``, any
        other partition ``(16.0, 32.0, 32.0)``, and the critical path at a
        transfer rate of 2 bytes/sec is expected to be 208.0.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Derive a ``NodeLatency`` from the recorded sizes of each compute node.

            Output, placeholder and get_attr nodes are skipped. The second
            ``NodeLatency`` argument is twice the total size when total and
            output sizes match, otherwise the output size.
            """
            latency_by_node: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op in {"output", "placeholder", "get_attr"}:
                    continue
                sizes = node.size_bytes
                if sizes.total_size == sizes.output_size:
                    second_arg = 2.0 * sizes.total_size
                else:
                    second_arg = sizes.output_size
                latency_by_node[node] = NodeLatency(sizes.total_size, second_arg)
            return latency_by_node

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        config = PartitionerConfig(devices)
        result = partitioner.partition_graph(traced, m, config)
        partitioned = result.module_with_submodules
        self.assertEqual(traced(a), partitioned(a))
        partitions = partitioner.partitions
        latency_by_partition = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for part in latency_by_partition:
            if part.partition_id == 0:
                expected = (128.0, 80.0, 160.0)
            else:
                expected = (16.0, 32.0, 32.0)
            assert latency_by_partition[part] == expected
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, latency_by_partition, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0
    def test_normalize_binary_operators(self):
        """``NormalizeOperators`` must rewrite every listed torch binary op.

        For both tensor/tensor and tensor/scalar call sites, the normalized
        module must give the same result as the traced one and its graph must
        not contain any of the original torch targets.
        """
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }

        def still_has_torch_target(graph_module):
            # True when any node still targets one of the ops under test.
            return any(n.target in ops_to_test for n in graph_module.graph.nodes)

        # Tensor/Tensor call site
        for binary_op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return binary_op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            lhs, rhs = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_allclose(traced(lhs, rhs), normalized(lhs, rhs))
            self.assertFalse(still_has_torch_target(normalized))

        # Tensor/scalar call site
        for binary_op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return binary_op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            sample = torch.randn(3, 4)
            torch.testing.assert_allclose(traced(sample), normalized(sample))
            self.assertFalse(still_has_torch_target(normalized))
Example #16
0
    def test_conv_bn_fusion(self):
        """Fuse an eval-mode resnet18: no BatchNorm2d may remain and outputs must match."""
        reference = resnet18().eval()
        fused = fuse(symbolic_trace(reference))

        # After fusion, no module in the result may be a BatchNorm2d.
        self.assertTrue(all(not isinstance(child, torch.nn.BatchNorm2d) for child in fused.modules()))

        batch, channels, height, width = 20, 3, 224, 224
        sample = torch.randn(batch, channels, height, width)

        self.assertEqual(fused(sample), reference(sample))
    def test_subgraph_creation(self):
        """Split a traced module round-robin into three partitions.

        Checks that node metadata set before the split is still present on
        the nodes of every generated submodule, and that the split module
        gives the same output as the traced one.
        """

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # Trace the model symbolically.
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # Round-robin partitioning state.
        next_partition = 0
        num_partitions = 3

        # Tag every non-output node so we can verify the tag survives the split.
        for traced_node in my_module_traced.graph.nodes:
            if traced_node.op == "output":
                continue
            traced_node.meta["test_meta_info"] = True

        def mod_partition(node: Node):
            nonlocal next_partition
            assigned = next_partition % num_partitions
            next_partition = (next_partition + 1) % num_partitions
            return assigned

        # Split the traced module into a module with submodules.
        module_with_submodules = split_module(my_module_traced, my_module,
                                              mod_partition)

        # The tag must still be set on every non-output node of each submodule.
        named = dict(module_with_submodules.named_modules())
        for top_node in module_with_submodules.graph.nodes:
            if top_node.op != "call_module":
                continue
            submod = named[top_node.target]
            self.assertTrue(isinstance(submod, torch.fx.GraphModule))
            for inner_node in submod.graph.nodes:
                if inner_node.op == "output":
                    continue
                stored_op = inner_node.meta.get("test_meta_info")
                self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        expected_out = my_module_traced(x, y)
        actual_out = module_with_submodules(x, y)

        self.assertEqual(expected_out, actual_out)
Example #18
0
        def test_kl_based_partition(self):
            """Partition with ``PartitionMode.kl_based`` over four devices.

            Checks that the partitioned module reproduces the traced module's
            output, that the first two DAG nodes have the expected sizes, and
            that the partitioned graph's latency at 2 bytes/sec equals 208.

            ``get_node_to_latency_mapping`` is not defined in this function;
            it is expected to be provided by the surrounding scope.
            """
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super(TestModule, self).__init__()
                    self.linear = torch.nn.Linear(4, 4)
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)
                    self.d = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = add_1 + self.c
                    linear_1 = self.linear(add_1)
                    add_3 = add_2 + linear_1
                    add_4 = add_2 + self.d
                    # NOTE(review): add_5 is computed but add_4 is returned.
                    # Left as-is because the expected sizes and latency below
                    # depend on the traced graph -- confirm the intent.
                    add_5 = add_3 + add_4
                    return add_4
            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            graph_manipulation.get_size_of_all_nodes(traced, [a])
            node_to_latency_mapping = get_node_to_latency_mapping(traced)
            transfer_rate_bytes_per_sec = 2
            devices = [
                Device('dev_0', 200, 0),
                Device('dev_1', 200, 1),
                Device('dev_2', 200, 2),
                Device('dev_3', 200, 3)
            ]
            partitioner = Partitioner()
            partitioner_config = PartitionerConfig(
                devices,
                mode=PartitionMode.kl_based,
                transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
                node_to_latency_mapping=node_to_latency_mapping
            )
            ret = partitioner.partition_graph(traced, m, partitioner_config)
            module_with_submodules = ret.module_with_submodules
            self.assertEqual(traced(a), module_with_submodules(a))
            dag = ret.dag
            # Compare the DAG nodes' byte sizes; comparing the node objects
            # themselves with integers could not succeed.
            # NOTE(review): assumes ``size_bytes`` is the intended attribute -- confirm.
            assert dag.nodes[0].size_bytes == 176
            assert dag.nodes[1].size_bytes == 112
            partition_to_latency_mapping = get_partition_to_latency_mapping(
                partitioner.partitions,
                node_to_latency_mapping
            )
            cost = get_latency_of_partitioned_graph(
                partitioner.partitions,
                partition_to_latency_mapping,
                transfer_rate_bytes_per_sec
            )
            assert cost == 208.
Example #19
0
    def test_partition_latency(self):
        """Check the latency tuple computed for each of the first two partitions.

        Partition 0 is expected to yield ``(128., 80., 160.)`` and partition 1
        ``(16., 32., 32)``.
        """

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_4 = add_2 + linear_1
                add_5 = add_2 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Derive a ``NodeLatency`` from the recorded sizes of each compute node.

            Output, placeholder and get_attr nodes are skipped. The second
            ``NodeLatency`` argument is twice the total size when total and
            output sizes match, otherwise the output size.
            """
            latency_by_node: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op in {'output', 'placeholder', 'get_attr'}:
                    continue
                sizes = node.size_bytes
                if sizes.total_size == sizes.output_size:
                    second_arg = 2. * sizes.total_size
                else:
                    second_arg = sizes.output_size
                latency_by_node[node] = NodeLatency(sizes.total_size, second_arg)
            return latency_by_node

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        GraphManipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device('dev_0', 200, 0), Device('dev_1', 200, 0)]
        partitioner = Partitioner()
        config = PartitionerConfig(devices, False)
        result = partitioner.partition_graph(traced, m, config)
        partitioned = result.module_with_submodules
        self.assertEqual(traced(a), partitioned(a))
        partitions = partitioner.partitions
        first_latency = get_latency_of_one_partition(
            partitions[0], node_to_latency_mapping)
        assert (128., 80., 160.) == first_latency
        second_latency = get_latency_of_one_partition(
            partitions[1], node_to_latency_mapping)
        assert (16., 32., 32) == second_latency
    def test_annotate_returns_with_schema(self):
        """``AnnotateTypesWithSchema`` must type every node except a known allow-list.

        resnet18 is traced twice: once with the default tracer and once with
        a tracer that treats only ``BatchNorm2d`` as a leaf module. After
        annotation, every node whose ``type`` is still ``None`` must be one of
        the listed ``(op, target)`` pairs. Each annotated module is also passed
        through ``torch.jit.script`` as a smoke test.
        """
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(
            traced_modules).transform()
        # (op, target) pairs allowed to stay untyped in the module-level trace.
        untyped_ok_modules = {
            ('placeholder', 'x'),
            ('call_function', operator.add),
            ('call_function', torch.flatten),
            ('output', 'output'),
        }
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                self.assertIn((node.op, node.target), untyped_ok_modules)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module,
                               module_qualified_name: str) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m,
                                                  FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals).transform()
        # (op, target) pairs allowed to stay untyped in the functional trace.
        untyped_ok_functionals = {
            ('placeholder', 'x'),
            ('call_function', torch.conv2d),
            # Return type differs based on boolean dispatch :(
            ('call_function', torch.nn.functional.max_pool2d),
            ('call_function', operator.add),
            ('call_function', torch.flatten),
            ('output', 'output'),
        }
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                self.assertIn((node.op, node.target), untyped_ok_functionals)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)
Example #21
0
 def test_replace_target_nodes_with(self):
      """Replace ``operator.add`` with ``operator.mul`` in a traced graph and check the result."""

      class testModule(torch.nn.Module):
          def forward(self, a, b):
              return a + b

      traced = symbolic_trace(testModule())
      lhs = torch.randn(1)
      rhs = torch.randn(1)
      # Before the rewrite the traced module adds its inputs.
      assert (lhs + rhs) == traced(lhs, rhs)
      graph_manipulation.replace_target_nodes_with(
          fx_module=traced,
          old_op="call_function",
          old_target=operator.add,
          new_op="call_function",
          new_target=operator.mul,
      )
      # After the rewrite it multiplies them.
      assert (lhs * rhs) == traced(lhs, rhs)
    def test_lack_of_devices(self):
        """Expect size-based partitioning to raise ``RuntimeError`` for this setup.

        A single-add module is traced and sized for two inputs of 4 elements,
        then offered two devices, each constructed with 4 as its second
        argument (presumably the available memory in bytes -- confirm against
        ``Device``).
        """

        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        # assertRaises keeps the check active under ``python -O`` (bare asserts
        # are stripped there) and reports a clear message when nothing is raised.
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)
Example #23
0
    def test_find_single_partition(self):
        """Partition a traced add module over three devices and compare outputs.

        The partitioned module returned by ``partition_graph`` must produce
        the same result as the traced module for the same inputs.
        """
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        GraphManipulation.get_size_of_all_nodes(traced, [a, b])

        partitioner = Partitioner()
        # Three identically sized devices.
        device_names = ('dev_0', 'dev_1', 'dev_2')
        devices = [Device(device_name, 125) for device_name in device_names]

        partition_result = partitioner.partition_graph(traced, m, devices)
        partitioned = partition_result.module_with_submodules
        self.assertEqual(traced(a, b), partitioned(a, b))
    def test_find_single_partition(self):
        """Partition a traced add module with a default-mode ``PartitionerConfig``.

        Checks that the partitioned module matches the traced module on the
        same inputs and that the first DAG node reports logical device ids
        ``[0]``.
        """
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])

        partitioner = Partitioner()
        # Three identically sized devices; the third argument is the index
        # of the device in this list.
        device_names = ("dev_0", "dev_1", "dev_2")
        devices = [
            Device(device_name, 125, position)
            for position, device_name in enumerate(device_names)
        ]
        partition_result = partitioner.partition_graph(
            traced, m, PartitionerConfig(devices))

        partitioned = partition_result.module_with_submodules
        dag = partition_result.dag
        self.assertEqual(traced(a, b), partitioned(a, b))
        assert dag.nodes[0].logical_device_ids == [0]
        def test_aot_based_partition(self):
            """Partition with a caller-supplied node-to-partition mapping (``aot_based`` mode).

            Every node that is not a placeholder, get_attr or output gets its
            own partition id, and each partition is mapped to logical device
            ``[0]``. The partitioned module must match the traced one, and
            each DAG node must report ``size_bytes == 48`` on device ``[0]``.
            """
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = self.c + add_1
                    return add_2

            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            GraphManipulation.get_size_of_all_nodes(traced, [a])

            # Number the remaining nodes in graph order: one partition per node.
            skipped_ops = {'placeholder', 'get_attr', 'output'}
            assigned_nodes = [
                graph_node for graph_node in traced.graph.nodes
                if graph_node.op not in skipped_ops
            ]
            node_to_partition_id = {
                graph_node: partition_id
                for partition_id, graph_node in enumerate(assigned_nodes)
            }
            partition_to_logical_devices = {
                partition_id: [0] for partition_id in range(len(assigned_nodes))
            }

            partitioner_config = PartitionerConfig(
                devices=[Device('dev_0', 200, 0)],
                mode=PartitionMode.aot_based,
                node_to_partition_mapping=node_to_partition_id,
                partition_to_logical_device_mapping=partition_to_logical_devices
            )
            partition_result = Partitioner().partition_graph(
                traced, m, partitioner_config)
            partitioned = partition_result.module_with_submodules
            dag = partition_result.dag

            self.assertEqual(partitioned(a), traced(a))
            for dag_node in dag.nodes:
                assert dag_node.size_bytes == 48
                assert dag_node.logical_device_ids == [0]
Example #26
0
    def test_subgraph_creation(self):
        """Split a traced module into submodules round-robin and compare outputs.

        The callback handed to ``split_module`` returns 0, 1, 2, 0, 1, ... on
        successive calls, ignoring the node it receives. The split module
        must produce the same output as the traced original.
        """
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # Trace the model symbolically.
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # Round-robin state shared with the callback below.
        NPARTITIONS = 3
        partition_counter = 0

        def mod_partition(node: Node):
            nonlocal partition_counter
            assigned = partition_counter % NPARTITIONS
            partition_counter = (assigned + 1) % NPARTITIONS
            return assigned

        # Break the traced module up into a module with submodules.
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        expected = my_module_traced(x, y)
        actual = module_with_submodules(x, y)
        self.assertEqual(expected, actual)
    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `NormalizeArgs` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if 'constructor' not in test_params:
                constructor = getattr(torch.nn, test_params['module_name'])
            else:
                constructor = test_params['constructor']

            if 'constructor_args' not in test_params:
                args = ()
            else:
                args = test_params['constructor_args']

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if 'input_fn' not in test_params:
                inputs = torch.randn(test_params['input_size'])
            else:
                inputs = test_params['input_fn']()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs, )

            params = ', '.join(f'v{i}' for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f'Test{mod.__class__.__name__}'
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {'torch': torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Now actually test arg normalization!
            traced = NormalizeArgs(traced).transform()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {
                'FractionalMaxPool2d', 'FractionalMaxPool3d', 'RReLU'
            }

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            # Ensure all args/kwargs are normalized into kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == 'call_module':
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)
Example #28
0
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same right-hand
    side operand into one large matrix multiplication.
               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
    M| A  |  T| B  | * K|    C    | =    |---------|
      ---- ,  |    |    |         |     T|  B * C  |
       K       ----      ---------        ---------
                K            R                R

    Args:
        in_mod: Module to transform. It is symbolically traced here and the
            rewrite is applied to the graph of the traced module.

    Returns:
        The traced module, recompiled after the rewrite.
    """
    gm = symbolic_trace(in_mod)

    # Maps each RHS matmul operand to the matmul nodes that use it as RHS.
    # Keys are Nodes, or attribute-name strings for get_attr operands (see
    # the TODO below).
    rhs_users: Dict[Node, List[Node]] = {}

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        # Only the RHS decides which matmuls can be merged; the LHS operands
        # are re-read from the matmul nodes when the merge is built.
        _, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [
            gm.graph.get_attr(operand) if isinstance(operand, str) else operand
            for operand in lhs_vals
        ]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs, ), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (
            merge_mm_cat,
            rhs,
        ), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_sizes = [
            gm.graph.call_function(get_first_dim, (operand, ), {})
            for operand in lhs
        ]
        merge_mm_split = gm.graph.call_function(torch.split,
                                                (merge_mm, merge_mm_sizes), {})
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        # All of the new nodes created above were inserted at the end, so we need to sort
        # the nodes topologically to make sure all definitions precede uses.
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint(in_mod)
    return gm
    def test_serialize_graph(self):
        """Serialize a traced module and its partitioned counterpart and inspect the result.

        Covers the number of serialized nodes/weights/modules, the recorded
        metadata of one weight and of the first node, one node argument, and
        the quantization info produced by `serialize_tensor_quantization`.
        """
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices,
                                               PartitionMode.sparse_nn)
        partition_result = partitioner.partition_graph(
            traced, m, partitioner_config)
        module_with_submodules = partition_result.module_with_submodules

        # Fix for now to add type/shape to output: tensor metadata derived
        # from `a` is attached to the nodes that are about to be serialized.
        for node in traced.graph.nodes:
            if node.op != "output":
                continue
            node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for mod in module_with_submodules.modules():
            if not isinstance(mod, GraphModule):
                continue
            for node in mod.graph.nodes:
                node.meta['tensor_meta'] = extract_tensor_metadata(a)
        for node in module_with_submodules.graph.nodes:
            node.meta['tensor_meta'] = extract_tensor_metadata(a)

        # The weight dicts are passed in empty and inspected afterwards.
        weights1, weights2 = {}, {}
        serialized_graph1 = graph_manipulation.serialize_module(
            traced, weights1)
        serialized_graph2 = graph_manipulation.serialize_module(
            module_with_submodules, weights2)

        assert len(weights1) == 4
        assert len(weights2) == 4
        assert len(serialized_graph1["nodes"]) == 10
        assert len(serialized_graph1["weights"]) == 4
        assert len(serialized_graph1["modules"]) == 0
        assert len(serialized_graph2["nodes"]) == 6
        assert len(serialized_graph2["weights"]) == 4
        assert len(serialized_graph2["modules"]) == 1

        linear_weight = serialized_graph1["weights"]["linear.weight"]
        assert linear_weight["shape"] == "[4, 4]"
        assert linear_weight["dtype"] == "torch.float32"
        assert linear_weight["is_quantized"] is False

        first_node = serialized_graph1["nodes"][0]
        assert first_node["shape"] == "[4]"
        assert first_node["dtype"] == "torch.float32"
        assert first_node["target"] == "a"
        assert first_node["op_code"] == "placeholder"
        assert first_node["name"] == "a"

        first_arg = serialized_graph1["nodes"][6]["args"][0]
        assert first_arg["name"] == "add_1"
        assert first_arg["is_node"] is True

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0,
            torch.quint8)
        per_tensor_info = graph_manipulation.serialize_tensor_quantization(
            q_tensor)
        per_channel_info = graph_manipulation.serialize_tensor_quantization(
            q_tensor_channel)
        assert per_tensor_info["qscheme"] == "torch.per_tensor_affine"
        assert per_tensor_info["q_scale"] == 1.0
        assert per_channel_info["qscheme"] == "torch.per_channel_affine"
        assert len(per_channel_info["q_per_channel_scales"]) == 2
    def test_cost_aware_partition(self):
        """Partition a traced module in ``cost_aware`` mode and check the resulting latency.

        The partitioned module must match the traced module on the same
        input, and the latency computed by ``get_latency_of_partitioned_graph``
        for the resulting partitions is expected to be exactly 160.
        """
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Build a ``NodeLatency`` for every node that is not an output, placeholder or get_attr."""
            # NOTE: the nested `test_kl_based_partition` defined further down
            # in this method also calls this helper, so its name must stay.
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op in {'output', 'placeholder', 'get_attr'}:
                    continue
                size = node.size_bytes
                # First argument is always the total size. The second is 1
                # when total and output sizes coincide, else the output size.
                if size.total_size == size.output_size:
                    second_arg = 1
                else:
                    second_arg = size.output_size
                node_to_latency_mapping[node] = NodeLatency(
                    size.total_size, second_arg)
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device('dev_0', 125, 0),
            Device('dev_1', 125, 1),
            Device('dev_2', 125, 2),
            Device('dev_3', 125, 3)
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping)
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec)
        assert critical_path_latency_sec == 160.

        def test_kl_based_partition(self):
            """Partition a traced module in ``kl_based`` mode and check DAG nodes and latency cost.

            NOTE(review): `get_node_to_latency_mapping` is not defined inside
            this function; it has to be resolved from an enclosing scope.
            """
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear = torch.nn.Linear(4, 4)
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)
                    self.d = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = add_1 + self.c
                    linear_1 = self.linear(add_1)
                    add_3 = add_2 + linear_1
                    add_4 = add_2 + self.d
                    # NOTE(review): add_5 is computed but add_4 is returned;
                    # confirm which value was meant to be the output.
                    add_5 = add_3 + add_4
                    return add_4

            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            graph_manipulation.get_size_of_all_nodes(traced, [a])
            node_to_latency_mapping = get_node_to_latency_mapping(traced)
            transfer_rate_bytes_per_sec = 2

            # Four identically sized devices; the third argument is the
            # index of the device in this list.
            device_names = ('dev_0', 'dev_1', 'dev_2', 'dev_3')
            devices = [
                Device(device_name, 200, position)
                for position, device_name in enumerate(device_names)
            ]

            partitioner = Partitioner()
            partitioner_config = PartitionerConfig(
                devices,
                mode=PartitionMode.kl_based,
                transfer_rate_bytes_per_sec=transfer_rate_bytes_per_sec,
                node_to_latency_mapping=node_to_latency_mapping)
            partition_result = partitioner.partition_graph(
                traced, m, partitioner_config)
            partitioned = partition_result.module_with_submodules
            self.assertEqual(traced(a), partitioned(a))

            dag = partition_result.dag
            # NOTE(review): these compare DAG node objects directly with
            # integers, while other tests in this file read
            # `node.size_bytes` on DAG nodes -- confirm the intended attribute.
            assert dag.nodes[0] == 176
            assert dag.nodes[1] == 112

            partition_to_latency_mapping = get_partition_to_latency_mapping(
                partitioner.partitions, node_to_latency_mapping)
            cost = get_latency_of_partitioned_graph(
                partitioner.partitions, partition_to_latency_mapping,
                transfer_rate_bytes_per_sec)
            assert cost == 208.

        def test_aot_based_partition(self):
            """Partition with a caller-supplied node-to-partition mapping (``aot_based`` mode).

            Every node that is not a placeholder, get_attr or output gets its
            own partition id, and each partition is mapped to logical device
            ``[0]``. The partitioned module must match the traced one, and
            each DAG node must report ``size_bytes == 48`` on device ``[0]``.
            """
            class TestModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.b = torch.rand(4)
                    self.c = torch.rand(4)

                def forward(self, a):
                    add_1 = a + self.b
                    add_2 = self.c + add_1
                    return add_2

            m = TestModule()
            traced = symbolic_trace(m)
            a = torch.rand(4)
            GraphManipulation.get_size_of_all_nodes(traced, [a])

            # Number the remaining nodes in graph order: one partition per node.
            skipped_ops = {'placeholder', 'get_attr', 'output'}
            assigned_nodes = [
                graph_node for graph_node in traced.graph.nodes
                if graph_node.op not in skipped_ops
            ]
            node_to_partition_id = {
                graph_node: partition_id
                for partition_id, graph_node in enumerate(assigned_nodes)
            }
            partition_to_logical_devices = {
                partition_id: [0] for partition_id in range(len(assigned_nodes))
            }

            partitioner_config = PartitionerConfig(
                devices=[Device('dev_0', 200, 0)],
                mode=PartitionMode.aot_based,
                node_to_partition_mapping=node_to_partition_id,
                partition_to_logical_device_mapping=partition_to_logical_devices
            )
            partition_result = Partitioner().partition_graph(
                traced, m, partitioner_config)
            partitioned = partition_result.module_with_submodules
            dag = partition_result.dag

            self.assertEqual(partitioned(a), traced(a))
            for dag_node in dag.nodes:
                assert dag_node.size_bytes == 48
                assert dag_node.logical_device_ids == [0]

        def test_replace_target_nodes_with(self):
            """Check that `replace_target_nodes_with` turns a traced add into a multiply."""
            class testModule(torch.nn.Module):
                def forward(self, a, b):
                    return a + b

            traced = symbolic_trace(testModule())
            lhs = torch.randn(1)
            rhs = torch.randn(1)

            # Freshly traced, the module still adds its two inputs.
            assert (lhs + rhs) == traced(lhs, rhs)

            # Ask the helper to swap call_function/operator.add for
            # call_function/operator.mul inside the traced module.
            rewrite_spec = dict(
                old_op="call_function",
                old_target=operator.add,
                new_op="call_function",
                new_target=operator.mul,
            )
            graph_manipulation.replace_target_nodes_with(
                fx_module=traced, **rewrite_spec)

            # The same inputs are now expected to be multiplied.
            assert (lhs * rhs) == traced(lhs, rhs)