def test_fuse_modules_with_pre_exist_adj_map():
    """Check the graph-source arguments of ``mt.fuse_modules``.

    The call must raise ``ValueError`` when neither ``dummy_input`` nor
    ``adjacency_map`` is supplied. With a dummy input, a pre-computed
    adjacency map, or both, the fused model must match ``fused_reference()``.
    """
    model = WrappedSequential(DummyA(), DummyB(), DummyD())

    # No way to obtain the graph -> fusion must be refused.
    with pytest.raises(ValueError):
        mt.fuse_modules(model, types_sequence, fuse_fn, dummy_input=None, adjacency_map=None)

    dummy_input = torch.randn(10, 10)
    # Build the summary graph on a copy so `model` itself is left untouched.
    adj_map = SummaryGraph(deepcopy(model), dummy_input).adjacency_map()

    # Every valid (dummy_input, adjacency_map) combination, checked in turn.
    graph_sources = (
        (dummy_input, None),     # graph derived from the dummy input only
        (None, adj_map),         # pre-existing adjacency map only
        (dummy_input, adj_map),  # both supplied
    )
    for source_input, source_map in graph_sources:
        fused = mt.fuse_modules(deepcopy(model), types_sequence, fuse_fn,
                                dummy_input=source_input, adjacency_map=source_map)
        compare_models(fused, fused_reference())
def test_adjacency_map(parallel, dedicated_modules):
    """Verify ``SummaryGraph.adjacency_map()`` on a small residual model.

    The model is conv -> bn -> tanh -> relu, and its output is summed with the
    conv output.

    Args:
        parallel: when True the model is wrapped in ``nn.DataParallel``. The
            expected module names then carry a ``'module.'`` prefix.
        dedicated_modules: forwarded to ``adjacency_map`` as
            ``dedicated_modules_only``. When True the 'Add' op is expected to be
            absent from the map.
    """
    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 10, 5)
            self.bn = nn.BatchNorm2d(10)
            self.post_conv_bn = nn.ModuleList([nn.Tanh(), nn.ReLU()])

        def forward(self, x):
            res = self.conv(x)
            y = self.bn(res)
            for layer in self.post_conv_bn:
                y = layer(y)
            # Residual sum; checked below as the 'Add' op.
            return y + res

    prefix = 'module.' if parallel else ''
    model = TestModel()
    if parallel:
        model = nn.DataParallel(model)

    graph = SummaryGraph(model, distiller.get_dummy_input(input_shape=(1, 3, 10, 10)))
    adj_map = graph.adjacency_map(dedicated_modules_only=dedicated_modules)

    # The only difference between the two modes is the presence of the Add op.
    assert len(adj_map) == (4 if dedicated_modules else 5)

    conv_meta = OpSimpleMetadata(prefix + 'conv', 'Conv')
    bn_meta = OpSimpleMetadata(prefix + 'bn', 'BatchNormalization')
    tanh_meta = OpSimpleMetadata(prefix + 'post_conv_bn.0', 'Tanh')
    relu_meta = OpSimpleMetadata(prefix + 'post_conv_bn.1', 'Relu')
    add_meta = OpSimpleMetadata('top_level_op', 'Add')

    def verify_entry(op_meta, predecessors=None, successors=None):
        # Look up op_meta in adj_map and compare it field by field with an
        # AdjacentsEntry built from the expected neighbours. A None argument
        # leaves that field at AdjacentsEntry's own default.
        assert op_meta.name in adj_map
        expected = AdjacentsEntry(op_meta)
        if predecessors is not None:
            expected.predecessors = predecessors
        if successors is not None:
            expected.successors = successors
        actual = adj_map[op_meta.name]
        assert actual.op_meta == expected.op_meta
        assert actual.predecessors == expected.predecessors
        assert actual.successors == expected.successors

    verify_entry(conv_meta,
                 successors=[bn_meta] if dedicated_modules else [bn_meta, add_meta])
    verify_entry(bn_meta, predecessors=[conv_meta], successors=[tanh_meta])
    verify_entry(tanh_meta, predecessors=[bn_meta], successors=[relu_meta])
    verify_entry(relu_meta, predecessors=[tanh_meta],
                 successors=[] if dedicated_modules else [add_meta])

    if dedicated_modules:
        assert add_meta.name not in adj_map
    else:
        verify_entry(add_meta, predecessors=[relu_meta, conv_meta])