예제 #1
0
def test_wider_conv():
    """Check that each ``wider_*`` helper returns the expected stub layer type.

    A model is produced from a generated graph, its weights are written back
    to the graph, and then individual layers from that graph are passed to
    ``wider_pre_conv``, ``wider_bn`` and ``wider_next_conv``.
    """
    generated = CnnGenerator(10, (28, 28, 3)).generate()
    model = generated.produce_model()
    model.set_weight_to_graph()
    net_graph = model.graph

    widened_pre = wider_pre_conv(net_graph.layer_list[2], 3)
    assert isinstance(widened_pre, StubConv2d)

    widened_bn = wider_bn(net_graph.layer_list[5], 3, 3, 3)
    assert isinstance(widened_bn, StubBatchNormalization2d)

    widened_next = wider_next_conv(net_graph.layer_list[6], 3, 3, 3)
    assert isinstance(widened_next, StubConv2d)
예제 #2
0
def test_wider_dense():
    """Replay a recorded widening operation and check layer 14's output width.

    After ``to_wider_model(14, 64)`` is applied to the graph, the last
    dimension of layer 14's output shape is expected to be 128.
    """
    graph = CnnGenerator(10, (32, 32, 3)).generate()
    graph.produce_model().set_weight_to_graph()

    # Each entry is (graph method name, *positional arguments).
    operations = [('to_wider_model', 14, 64)]
    for method_name, *method_args in operations:
        transform = getattr(graph, method_name)
        transform(*method_args)
        # Rebuild the model after every step, as the recorded history is replayed.
        graph.produce_model()

    last_dim = graph.layer_list[14].output.shape[-1]
    assert last_dim == 128
예제 #3
0
def test_node_consistency():
    """Check layer/node shape agreement before and after widening layer 6.

    For every layer in the graph, the shape stored on its output node must
    equal the layer's own ``output_shape``; this is verified both before and
    after ``to_wider_model(6, 64)``.
    """
    graph = CnnGenerator(10, (32, 32, 3)).generate()

    def assert_shapes_agree():
        # Reads graph.layer_list at call time, so it sees the current layers.
        for current in graph.layer_list:
            assert current.output.shape == current.output_shape

    assert graph.layer_list[6].output.shape == (16, 16, 64)
    assert_shapes_agree()

    graph.to_wider_model(6, 64)

    assert graph.layer_list[6].output.shape == (16, 16, 128)
    assert_shapes_agree()
예제 #4
0
def test_long_transform2():
    """Smoke-test a forward pass after add-skip and concat-skip transforms.

    The model's output is not inspected; the test passes as long as building
    the model and calling it on a random batch does not raise.
    """
    graph = CnnGenerator(10, (28, 28, 1)).generate()
    graph.to_add_skip_model(2, 3)
    graph.to_concat_skip_model(2, 3)

    model = graph.produce_model()
    random_values = np.random.random((10, 1, 28, 28))
    batch = torch.Tensor(random_values)
    model(batch)
예제 #5
0
def test_graph_size():
    """Check that ``size()`` of a freshly generated graph matches the pinned value."""
    expected_size = 7254
    graph = CnnGenerator(10, (32, 32, 3)).generate()
    assert graph.size() == expected_size