Exemplo n.º 1
0
def test_fuse_modules_with_pre_exist_adj_map():
    """Exercise fuse_modules with each combination of dummy input / adjacency map.

    Passing neither a dummy input nor an adjacency map must raise ValueError.
    Passing only a dummy input, only a pre-built adjacency map, or both is
    expected to yield a model matching the fused reference.
    """
    base_model = WrappedSequential(DummyA(), DummyB(), DummyD())

    # With no graph source at all, the call is expected to be rejected.
    with pytest.raises(ValueError):
        mt.fuse_modules(base_model, types_sequence, fuse_fn,
                        dummy_input=None, adjacency_map=None)

    sample_input = torch.randn(10, 10)
    # Build the adjacency map up front from a copy, leaving base_model untouched.
    graph = SummaryGraph(deepcopy(base_model), sample_input)
    prebuilt_map = graph.adjacency_map()

    # Same order as before: dummy input only, pre-built map only, then both.
    graph_sources = (
        {'dummy_input': sample_input, 'adjacency_map': None},
        {'dummy_input': None, 'adjacency_map': prebuilt_map},
        {'dummy_input': sample_input, 'adjacency_map': prebuilt_map},
    )
    for source_kwargs in graph_sources:
        fused = mt.fuse_modules(deepcopy(base_model), types_sequence, fuse_fn,
                                **source_kwargs)
        compare_models(fused, fused_reference())
Exemplo n.º 2
0
def test_adjacency_map(parallel, dedicated_modules):
    """Validate the adjacency map SummaryGraph builds for a small residual model.

    The model runs conv -> bn -> tanh -> relu and then adds the conv output
    to the result. The test expects an entry per module op, plus an 'Add'
    entry named 'top_level_op' unless only dedicated modules are requested.

    Args:
        parallel: When truthy, the model is wrapped in nn.DataParallel and
            the expected module names carry a 'module.' prefix.
        dedicated_modules: Forwarded to adjacency_map() as
            dedicated_modules_only. When truthy, the 'Add' op is expected to
            be missing from the map and from its neighbours' successor lists.
    """
    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 10, 5)
            self.bn = nn.BatchNorm2d(10)
            self.post_conv_bn = nn.ModuleList([nn.Tanh(), nn.ReLU()])

        def forward(self, x):
            conv_out = self.conv(x)
            out = self.bn(conv_out)
            for layer in self.post_conv_bn:
                out = layer(out)
            # Residual connection: this is the only op with no dedicated module.
            return out + conv_out

    def build_expected(op_meta, predecessors, successors):
        # None means "leave AdjacentsEntry's own default in place".
        entry = AdjacentsEntry(op_meta)
        if predecessors is not None:
            entry.predecessors = predecessors
        if successors is not None:
            entry.successors = successors
        return entry

    def verify_entry(actual, expected):
        assert actual.op_meta == expected.op_meta
        assert actual.predecessors == expected.predecessors
        assert actual.successors == expected.successors

    name_prefix = 'module.' if parallel else ''

    model = nn.DataParallel(TestModel()) if parallel else TestModel()
    dummy_input = distiller.get_dummy_input(input_shape=(1, 3, 10, 10))
    summary_graph = SummaryGraph(model, dummy_input)
    adj_map = summary_graph.adjacency_map(dedicated_modules_only=dedicated_modules)

    expected_size = 4 if dedicated_modules else 5
    assert len(adj_map) == expected_size

    conv_meta = OpSimpleMetadata(name_prefix + 'conv', 'Conv')
    bn_meta = OpSimpleMetadata(name_prefix + 'bn', 'BatchNormalization')
    tanh_meta = OpSimpleMetadata(name_prefix + 'post_conv_bn.0', 'Tanh')
    relu_meta = OpSimpleMetadata(name_prefix + 'post_conv_bn.1', 'Relu')
    add_meta = OpSimpleMetadata('top_level_op', 'Add')

    # (op, expected predecessors, expected successors), checked in graph order.
    module_expectations = [
        (conv_meta, None,
         [bn_meta] if dedicated_modules else [bn_meta, add_meta]),
        (bn_meta, [conv_meta], [tanh_meta]),
        (tanh_meta, [bn_meta], [relu_meta]),
        (relu_meta, [tanh_meta],
         [] if dedicated_modules else [add_meta]),
    ]
    for op_meta, predecessors, successors in module_expectations:
        key = op_meta.name
        assert key in adj_map
        verify_entry(adj_map[key],
                     build_expected(op_meta, predecessors, successors))

    add_key = add_meta.name
    if dedicated_modules:
        assert add_key not in adj_map
    else:
        assert add_key in adj_map
        verify_entry(adj_map[add_key],
                     build_expected(add_meta, [relu_meta, conv_meta], None))