def test_simple_resnet_graph():
    """Check graph construction for a small architecture with a branching link.

    The architecture is made of four linear layers followed by a softmax.
    Besides the sequential links, ``layer_links`` contains ``(1, 3)`` next to
    ``(1, 2)`` and ``(2, 3)``. The graph built from an all-ones input of
    size 4 is expected to have a 26x26 adjacency matrix and 128 edges.
    """
    expected_node_count = 26
    expected_edge_count = 128

    layers = [
        LinearLayer(4, 4),
        LinearLayer(4, 4),
        LinearLayer(4, 4),
        LinearLayer(4, 10),
        SoftMaxLayer(),
    ]
    # Pairs of layer indices; -1 presumably stands for the network input
    # (TODO confirm against Architecture).
    links = [(-1, 0), (0, 1), (1, 2), (1, 3), (2, 3), (3, 4)]

    architecture = Architecture(
        preprocess=lambda x: x,
        layers=layers,
        layer_links=links,
    )
    architecture.build_matrices()

    graph = Graph.from_architecture_and_data_point(architecture, torch.ones(4))
    dense_adjacency = graph.get_adjacency_matrix().todense()

    assert np.shape(dense_adjacency) == (expected_node_count, expected_node_count)

    assert len(graph.get_edge_list()) == expected_edge_count
    # The dense matrix must hold twice as many positive entries as there are
    # edges (presumably one entry per direction of each edge).
    assert len(np.where(dense_adjacency > 0)[0]) == 2 * expected_edge_count

    # Kept from the original test: dumps the edges and the pre-softmax index,
    # which also exercises both accessors.
    print(graph.get_edge_list())
    print(architecture.get_pre_softmax_idx())
 ),
 BatchNorm2d(channels=512, activ=F.relu),
 ConvLayer(
     in_channels=512,
     out_channels=512,
     kernel_size=3,
     stride=1,
     padding=1,
     bias=False,
 ),
 BatchNorm2d(channels=512),
 ReluLayer(),
 # End part
 AvgPool2dLayer(kernel_size=4),
 LinearLayer(512, 10),
 SoftMaxLayer(),
 # Layer to reduce dimension in residual blocks
 ConvLayer(
     in_channels=64,
     out_channels=128,
     kernel_size=1,
     stride=2,
     padding=0,
     bias=False,
 ),
 ConvLayer(
     in_channels=128,
     out_channels=256,
     kernel_size=1,
     stride=2,
     padding=0,