def test_fuse_modules_with_pre_exist_adj_map():
    """fuse_modules requires a dummy input and/or a pre-computed adjacency map.

    Passing neither must raise ValueError; every valid combination of the two
    arguments must yield the same fused model.
    """
    base_model = WrappedSequential(DummyA(), DummyB(), DummyD())

    # Neither a dummy input nor an adjacency map -> the graph can't be traced
    with pytest.raises(ValueError):
        mt.fuse_modules(base_model, types_sequence, fuse_fn,
                        dummy_input=None, adjacency_map=None)

    sample_input = torch.randn(10, 10)
    graph = SummaryGraph(deepcopy(base_model), sample_input)
    precomputed_map = graph.adjacency_map()

    # All three valid argument combinations should produce an identical result
    arg_combos = (
        (sample_input, None),             # dummy input only
        (None, precomputed_map),          # pre-existing adjacency map only
        (sample_input, precomputed_map),  # both supplied
    )
    for inp, adj in arg_combos:
        fused = mt.fuse_modules(deepcopy(base_model), types_sequence, fuse_fn,
                                dummy_input=inp, adjacency_map=adj)
        compare_models(fused, fused_reference())
def fused_reference():
    """Expected fusion result for a fusable (A, B/C, D) sequence.

    The leading DummyA absorbs the rest of the sequence; the remaining two
    positions are replaced with no-op Identity modules.
    """
    identity_tail = [nn.Identity() for _ in range(2)]
    return WrappedSequential(DummyA(), *identity_tail)
def test_fuse_modules(parallel):
    """Exercise mt.fuse_modules on non-fusable, wrongly-ordered, and fusable
    module sequences, including nested sequences and a split/join topology.

    `parallel` is a fixture/parameter forwarded to fuse_and_check — presumably
    it wraps the model in DataParallel when truthy (TODO confirm against the
    fuse_and_check definition).
    """
    input_shape = (10, 10)

    # Simple negative tests

    # Not Fusable: any single module flagged fuseable=False blocks the fusion,
    # so the model must come back unchanged.
    model = WrappedSequential(DummyA(fuseable=False), DummyB(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
    model = WrappedSequential(DummyA(), DummyB(fuseable=False), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
    model = WrappedSequential(DummyA(), DummyC(), DummyD(fuseable=False))
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    # Wrong sequence: module orderings that don't match types_sequence must
    # also leave the model unchanged.
    model = WrappedSequential(DummyB())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
    model = WrappedSequential(DummyB(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
    model = WrappedSequential(DummyA(), DummyB(), DummyA(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
    model = WrappedSequential(DummyA(), DummyB(), DummyC(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    # Simple positive tests

    # Simple sequence 1: (A, B, D) fuses to the reference model
    model = WrappedSequential(DummyA(), DummyB(), DummyD())
    fuse_and_check(model, fused_reference(), input_shape, parallel)

    # Simple sequence 2: (A, C, D) fuses to the same reference
    model = WrappedSequential(DummyA(), DummyC(), DummyD())
    fuse_and_check(model, fused_reference(), input_shape, parallel)

    # 2 sequences: both nested sub-sequences fuse independently
    model = WrappedSequential(WrappedSequential(DummyA(), DummyB(), DummyD()),
                              WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = WrappedSequential(fused_reference(), fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # "Complex" tests

    # 2 sequences with wrong sequence between them: the (A, B) pair in the
    # middle stays intact while the surrounding sub-sequences fuse.
    model = WrappedSequential(WrappedSequential(DummyA(), DummyB(), DummyD()),
                              DummyA(), DummyB(),
                              WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = WrappedSequential(fused_reference(), DummyA(), DummyB(),
                                 fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # "Complex" model: two fusable branches feeding an element-wise add
    class SplitJoinModel(nn.Module):
        def __init__(self, m1, m2):
            super(SplitJoinModel, self).__init__()
            # Split is constructed but not applied in forward (see the
            # commented-out line below) — both branches see the full input.
            self.split = Split(int(input_shape[0] / 2))
            self.m1 = m1
            self.m2 = m2
            self.add = EltwiseAdd()

        def forward(self, x):
            # x1, x2 = self.split(x)
            y1 = self.m1(x)
            y2 = self.m2(x)
            return self.add(y1, y2)

    model = SplitJoinModel(WrappedSequential(DummyA(), DummyB(), DummyD()),
                           WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = SplitJoinModel(fused_reference(), fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # Node with multiple outputs: a module whose output feeds more than one
    # consumer must not be fused, so the model comes back unchanged.
    model = BypassModel((DummyA(), DummyB()), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)
adjacency_map=adj_map) compare_models(fused_both, fused_reference()) ############################################################################### # Test BN folding for inference ############################################################################### # This warning seems to be a bug in batch_norm implementation, which compares a tensor to the value 1 @pytest.mark.filterwarnings( 'ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect' ) @pytest.mark.parametrize( 'model, input_shape', [(WrappedSequential(nn.ReLU(), nn.BatchNorm1d(5)), (10, 5)), (WrappedSequential(nn.Conv1d(10, 20, 3), nn.ReLU()), (10, 10, 10)), (WrappedSequential(nn.Conv2d(10, 20, 3), nn.BatchNorm2d(20, track_running_stats=False)), (10, 10, 50, 50)), (WrappedSequential(nn.Linear( 10, 20), nn.BatchNorm1d(20, track_running_stats=False)), (10, 10)), (BypassModel( (nn.Conv2d(10, 20, 3), ), nn.BatchNorm2d(20)), (10, 10, 50, 50))], ids=[ 'relu->bn', 'conv->relu', 'conv->bn_no_stats', 'linear->bn_no_stats', 'conv_multiple_outputs->bn' ]) def test_fold_batch_norms_inference_no_fold(model, input_shape): orig_model = deepcopy(model) folded_model = mt.fold_batch_norms(model,