def test_wider_dense_output_width():
    """Check that ``to_wider_model(14, 64)`` leaves layer 14 with a last output dimension of 128.

    Renamed from ``test_wider_dense``: this module also defines a later
    ``test_wider_dense`` (checking ``legal_graph``). A later ``def`` with the same
    name rebinds the module attribute, so under the old name this test was
    silently discarded and never collected by pytest.
    """
    graph = CnnGenerator(10, (32, 32, 3)).generate()
    graph.produce_model().set_weight_to_graph()
    history = [('to_wider_model', 14, 64)]
    # Each entry is (graph method name, *positional args); replay them in order.
    for method_name, *method_args in history:
        getattr(graph, method_name)(*method_args)
    # Rebuild the model after the transforms, as the original test did.
    graph.produce_model()
    assert graph.layer_list[14].output.shape[-1] == 128
def test_long_transform():
    """Replay a widen -> conv-deepen -> concat-skip sequence and require a legal graph."""
    generated = CnnGenerator(10, (32, 32, 3)).generate()
    transforms = (
        ('to_wider_model', 1, 256),
        ('to_conv_deeper_model', 1, 3),
        ('to_concat_skip_model', 5, 9),
    )
    # Each transform is (graph method name, *positional args).
    for op_name, *op_args in transforms:
        transform = getattr(generated, op_name)
        transform(*op_args)
    generated.produce_model()
    assert legal_graph(generated)
def test_wider_dense():
    """Widen layer 14 via ``to_wider_model(14, 64)`` and require the graph to stay legal."""
    generated = CnnGenerator(10, (32, 32, 3)).generate()
    built = generated.produce_model()
    built.set_weight_to_graph()
    transforms = (('to_wider_model', 14, 64),)
    # Each transform is (graph method name, *positional args).
    for op_name, *op_args in transforms:
        transform = getattr(generated, op_name)
        transform(*op_args)
    generated.produce_model()
    assert legal_graph(generated)
def test_long_transform2():
    """Add both an add-skip and a concat-skip between layers 2 and 3, then run a forward pass.

    Previously the test only checked that the forward pass did not raise. It
    now also pins the output shape, so a structurally broken graph that still
    runs is caught.
    """
    graph = CnnGenerator(10, (28, 28, 1)).generate()
    graph.to_add_skip_model(2, 3)
    graph.to_concat_skip_model(2, 3)
    model = graph.produce_model()
    # Batch of 10 single-channel 28x28 inputs in (N, C, H, W) order.
    batch = torch.Tensor(np.random.random((10, 1, 28, 28)))
    output = model(batch)
    # NOTE(review): assumes the generated model's final layer emits one value per
    # output requested from CnnGenerator (10) for each sample in the batch — confirm.
    assert tuple(output.shape) == (10, 10)