def test_layer_structure_exclude_layer(self):
    """Excluding Sigmoid layers from an Input > Sigmoid graph leaves only the input.

    The forward graph of the two-layer network is filtered with
    ``exclude_layer_from_graph`` using the ``layers.Sigmoid`` class. The
    result must equal an OrderedDict holding just the input layer, mapped to
    an empty list.
    """
    source_layer = layers.Input(10)
    sigmoid_layer = layers.Sigmoid(1)
    network = source_layer > sigmoid_layer

    forward_graph = network.graph.forward_graph
    filtered_graph = exclude_layer_from_graph(forward_graph, [layers.Sigmoid])

    # Single-entry OrderedDict: the only surviving node is the input layer.
    expected = OrderedDict([(source_layer, [])])
    self.assertEqual(expected, filtered_graph)
def test_layer_structure_exclude_layer_nothing_to_exclude(self):
    """An empty collection of layer types leaves the forward graph unchanged.

    Builds an Input > Sigmoid network and calls ``exclude_layer_from_graph``
    with an empty tuple of layer classes. The returned graph must compare
    equal to the original forward graph.
    """
    network = layers.Input(10) > layers.Sigmoid(1)
    forward_graph = network.graph.forward_graph

    self.assertEqual(
        forward_graph,
        exclude_layer_from_graph(forward_graph, ()),
    )