def test_wider_conv():
    """Check the ``wider_*`` helpers return stub layers of the expected types.

    A CNN graph is generated, turned into a model, and the model's weights
    are written back to the graph. Then each widening helper is applied to
    one layer of the graph (indices 2, 5 and 6, all with argument 3), and
    the type of each result is verified.
    """
    built_model = CnnGenerator(10, (28, 28, 3)).generate().produce_model()
    built_model.set_weight_to_graph()
    stub_graph = built_model.graph

    pre_conv = wider_pre_conv(stub_graph.layer_list[2], 3)
    assert isinstance(pre_conv, StubConv2d)

    batch_norm = wider_bn(stub_graph.layer_list[5], 3, 3, 3)
    assert isinstance(batch_norm, StubBatchNormalization2d)

    next_conv = wider_next_conv(stub_graph.layer_list[6], 3, 3, 3)
    assert isinstance(next_conv, StubConv2d)
def test_wider_dense():
    """Widen layer 14 with ``to_wider_model`` and check its output width.

    The graph's weights are synced from a produced model before widening.
    A model is produced again afterwards, and layer 14's last output
    dimension is expected to be 128.
    """
    cnn_graph = CnnGenerator(10, (32, 32, 3)).generate()
    cnn_graph.produce_model().set_weight_to_graph()

    # Single transformation, applied directly (originally dispatched by name).
    cnn_graph.to_wider_model(14, 64)
    cnn_graph.produce_model()

    assert cnn_graph.layer_list[14].output.shape[-1] == 128
def test_node_consistency():
    """Check node output shapes agree with layer output shapes around a widening.

    Layer 6's output shape is pinned before and after ``to_wider_model(6, 64)``.
    At both points, every layer's ``output.shape`` must equal its
    ``output_shape``.
    """
    def _check_all_layers(target_graph):
        # Every layer's node shape must match the shape the layer reports.
        for current in target_graph.layer_list:
            assert current.output.shape == current.output_shape

    cnn_graph = CnnGenerator(10, (32, 32, 3)).generate()
    assert cnn_graph.layer_list[6].output.shape == (16, 16, 64)
    _check_all_layers(cnn_graph)

    cnn_graph.to_wider_model(6, 64)
    assert cnn_graph.layer_list[6].output.shape == (16, 16, 128)
    _check_all_layers(cnn_graph)
def test_long_transform2():
    """Stack an add-skip and a concat-skip on nodes 2 and 3, then run a forward pass.

    The test only checks that the transformed graph produces a model that
    can be called on a random batch without raising.
    """
    cnn_graph = CnnGenerator(10, (28, 28, 1)).generate()
    cnn_graph.to_add_skip_model(2, 3)
    cnn_graph.to_concat_skip_model(2, 3)
    torch_model = cnn_graph.produce_model()

    # Batch of 10 random samples. The channel axis sits before the spatial
    # axes here, unlike the (28, 28, 1) shape given to the generator —
    # presumably the model expects channels-first input; confirm.
    random_batch = torch.Tensor(np.random.random((10, 1, 28, 28)))
    torch_model(random_batch)
def test_graph_size():
    """Pin ``size()`` of the default generated CNN graph for a 32x32x3 input."""
    expected_size = 7254
    cnn_graph = CnnGenerator(10, (32, 32, 3)).generate()
    assert cnn_graph.size() == expected_size