def test_fuse_modules_with_pre_exist_adj_map():
    """Fusion must reject a call with no graph information, and must produce
    the same fused model whether it is given a dummy input, a pre-computed
    adjacency map, or both."""
    model = WrappedSequential(DummyA(), DummyB(), DummyD())

    # Neither a dummy input nor an adjacency map: no graph can be built
    with pytest.raises(ValueError):
        mt.fuse_modules(model,
                        types_sequence,
                        fuse_fn,
                        dummy_input=None,
                        adjacency_map=None)

    dummy_input = torch.randn(10, 10)
    adj_map = SummaryGraph(deepcopy(model), dummy_input).adjacency_map()

    # Every valid combination of graph inputs must yield the fused reference:
    # 1. dummy input only  2. pre-existing adjacency map only  3. both
    graph_arg_combos = [
        {'dummy_input': dummy_input, 'adjacency_map': None},
        {'dummy_input': None, 'adjacency_map': adj_map},
        {'dummy_input': dummy_input, 'adjacency_map': adj_map},
    ]
    for graph_args in graph_arg_combos:
        fused = mt.fuse_modules(deepcopy(model),
                                types_sequence,
                                fuse_fn,
                                **graph_args)
        compare_models(fused, fused_reference())
def fused_reference():
    """Build the expected post-fusion model: the leading dummy module is
    kept and the rest of the sequence is replaced by identity modules."""
    replaced = (nn.Identity(), nn.Identity())
    return WrappedSequential(DummyA(), *replaced)
def test_fuse_modules(parallel):
    """Exercise module fusion on negative cases (non-fuseable modules, wrong
    module sequences), simple positive cases, and "complex" topologies
    (nested sequences, split/join graphs, multi-output nodes).

    Fix: removed two stray, mis-indented lines at the end of the function
    (`adjacency_map=adj_map)` / `compare_models(fused_both, ...)`) that were
    left over from a duplicated paste of the previous test. They referenced
    undefined locals and made this function a syntax error.
    """
    input_shape = (10, 10)

    # Simple negative tests

    # Not Fusable: any non-fuseable member blocks fusion of the sequence
    model = WrappedSequential(DummyA(fuseable=False), DummyB(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    model = WrappedSequential(DummyA(), DummyB(fuseable=False), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    model = WrappedSequential(DummyA(), DummyC(), DummyD(fuseable=False))
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    # Wrong sequence: module orderings that don't match types_sequence
    model = WrappedSequential(DummyB())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    model = WrappedSequential(DummyB(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    model = WrappedSequential(DummyA(), DummyB(), DummyA(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    model = WrappedSequential(DummyA(), DummyB(), DummyC(), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)

    # Simple positive tests

    # Simple sequence 1
    model = WrappedSequential(DummyA(), DummyB(), DummyD())
    fuse_and_check(model, fused_reference(), input_shape, parallel)

    # Simple sequence 2
    model = WrappedSequential(DummyA(), DummyC(), DummyD())
    fuse_and_check(model, fused_reference(), input_shape, parallel)

    # 2 sequences: both nested sequences should fuse independently
    model = WrappedSequential(WrappedSequential(DummyA(), DummyB(), DummyD()),
                              WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = WrappedSequential(fused_reference(), fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # "Complex" tests

    # 2 sequences with wrong sequence between them: only the nested,
    # well-formed sequences fuse; the loose A->B in the middle does not
    model = WrappedSequential(WrappedSequential(DummyA(), DummyB(), DummyD()),
                              DummyA(), DummyB(),
                              WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = WrappedSequential(fused_reference(), DummyA(), DummyB(),
                                 fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # "Complex" model: the same input feeds two parallel branches whose
    # outputs are summed; each branch should fuse independently
    class SplitJoinModel(nn.Module):
        def __init__(self, m1, m2):
            super(SplitJoinModel, self).__init__()
            # NOTE(review): `split` is constructed but never called in
            # forward (see commented line below) — presumably kept so the
            # module hierarchy matches an earlier variant of this test
            self.split = Split(int(input_shape[0] / 2))
            self.m1 = m1
            self.m2 = m2
            self.add = EltwiseAdd()

        def forward(self, x):
            # x1, x2 = self.split(x)
            y1 = self.m1(x)
            y2 = self.m2(x)
            return self.add(y1, y2)

    model = SplitJoinModel(WrappedSequential(DummyA(), DummyB(), DummyD()),
                           WrappedSequential(DummyA(), DummyC(), DummyD()))
    expected = SplitJoinModel(fused_reference(), fused_reference())
    fuse_and_check(model, expected, input_shape, parallel)

    # Node with multiple outputs: must not fuse, model stays unchanged
    model = BypassModel((DummyA(), DummyB()), DummyD())
    fuse_and_check(model, deepcopy(model), input_shape, parallel)


###############################################################################
# Test BN folding for inference
###############################################################################


# This warning seems to be a bug in batch_norm implementation, which compares a tensor to the value 1
@pytest.mark.filterwarnings(
    'ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect'
)
@pytest.mark.parametrize(
    'model, input_shape',
    [(WrappedSequential(nn.ReLU(), nn.BatchNorm1d(5)), (10, 5)),
     (WrappedSequential(nn.Conv1d(10, 20, 3), nn.ReLU()), (10, 10, 10)),
     (WrappedSequential(nn.Conv2d(10, 20, 3),
                        nn.BatchNorm2d(20, track_running_stats=False)),
      (10, 10, 50, 50)),
     (WrappedSequential(nn.Linear(
         10, 20), nn.BatchNorm1d(20, track_running_stats=False)), (10, 10)),
     (BypassModel(
         (nn.Conv2d(10, 20, 3), ), nn.BatchNorm2d(20)), (10, 10, 50, 50))],
    ids=[
        'relu->bn', 'conv->relu', 'conv->bn_no_stats', 'linear->bn_no_stats',
        'conv_multiple_outputs->bn'
    ])
def test_fold_batch_norms_inference_no_fold(model, input_shape):
    orig_model = deepcopy(model)
    folded_model = mt.fold_batch_norms(model,