def test_simple_convolution(self):
    """Check the node following 'conv' after ReplaceConvolutionTranspose runs.

    Builds the graph conv -> reshape_conv -> scale_shift from
    ``self.nodes_attributes`` and applies ReplaceConvolutionTranspose. It then
    asserts two things about the node that 'conv' feeds:
    - its op is 'Transpose';
    - the value produced for its input port 1 equals [0, 3, 2, 1].
    """
    edges = [
        ('conv', 'reshape_conv'),
        ('reshape_conv', 'scale_shift'),
    ]
    graph = build_graph(self.nodes_attributes, edges)
    ReplaceConvolutionTranspose().find_and_replace_pattern(graph)

    # NOTE(review): Node is looked up via the 'name' attribute of the 'conv'
    # entry; this presumes the name resolves to the node id — confirm.
    conv_id = graph.nodes['conv']['name']
    transpose = Node(graph, conv_id).out_node()
    self.assertEqual(transpose.op, 'Transpose')

    # The producer of the Transpose's second input holds the permutation order.
    order = transpose.in_node(1).in_node().value
    expected_order = np.array([0, 3, 2, 1])
    self.assertTrue(np.array_equal(order, expected_order))
def test_conv_pool(self):
    """Check the node following 'pool' after ReplaceConvolutionTranspose runs.

    Builds the chain conv -> reshape_conv -> reshape_pool -> pool ->
    reshape_after_pool -> fc from ``self.nodes_attributes`` and applies
    ReplaceConvolutionTranspose. It then asserts two things about the node
    that 'pool' feeds:
    - its op is 'Transpose';
    - the value produced for its input port 1 equals [0, 3, 2, 1].
    """
    edges = [
        ('conv', 'reshape_conv'),
        ('reshape_conv', 'reshape_pool'),
        ('reshape_pool', 'pool'),
        ('pool', 'reshape_after_pool'),
        ('reshape_after_pool', 'fc'),
    ]
    graph = build_graph(self.nodes_attributes, edges)
    ReplaceConvolutionTranspose().find_and_replace_pattern(graph)

    # NOTE(review): Node is looked up via the 'name' attribute of the 'pool'
    # entry; this presumes the name resolves to the node id — confirm.
    pool_id = graph.nodes['pool']['name']
    transpose = Node(graph, pool_id).out_node()
    self.assertEqual(transpose.op, 'Transpose')

    # The producer of the Transpose's second input holds the permutation order.
    order = transpose.in_node(1).in_node().value
    expected_order = np.array([0, 3, 2, 1])
    self.assertTrue(np.array_equal(order, expected_order))