Exemplo n.º 1
0
 def test_simple_convolution(self):
     """Check that a Transpose with order [0, 3, 2, 1] follows 'conv'.

     Builds the graph conv -> reshape_conv -> scale_shift, runs the
     ReplaceConvolutionTranspose pass on it, and inspects the node returned
     by the convolution's ``out_node()``.
     """
     graph = build_graph(self.nodes_attributes, [
         ('conv', 'reshape_conv'),
         ('reshape_conv', 'scale_shift'),
     ])
     ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
     # NOTE(review): the node is looked up via its 'name' attribute, which
     # assumes the name equals the node id in this fixture — confirm.
     conv_node = Node(graph, graph.nodes['conv']['name'])
     permute = conv_node.out_node()
     self.assertEqual(permute.op, 'Transpose')
     # The permutation order is taken from the producer of the Transpose's
     # input 1. assert_array_equal prints both arrays on mismatch, whereas
     # assertTrue(np.array_equal(...)) only reports "False is not true".
     np.testing.assert_array_equal(
         permute.in_node(1).in_node().value,
         np.array([0, 3, 2, 1]),
     )
Exemplo n.º 2
0
 def test_conv_pool(self):
     """Check that a Transpose with order [0, 3, 2, 1] follows 'pool'.

     Builds the graph
     conv -> reshape_conv -> reshape_pool -> pool -> reshape_after_pool -> fc,
     runs the ReplaceConvolutionTranspose pass on it, and inspects the node
     returned by the pooling node's ``out_node()``.
     """
     graph = build_graph(self.nodes_attributes, [
         ('conv', 'reshape_conv'),
         ('reshape_conv', 'reshape_pool'),
         ('reshape_pool', 'pool'),
         ('pool', 'reshape_after_pool'),
         ('reshape_after_pool', 'fc'),
     ])
     ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
     # NOTE(review): the node is looked up via its 'name' attribute, which
     # assumes the name equals the node id in this fixture — confirm.
     pool_node = Node(graph, graph.nodes['pool']['name'])
     permute = pool_node.out_node()
     self.assertEqual(permute.op, 'Transpose')
     # The permutation order is taken from the producer of the Transpose's
     # input 1. assert_array_equal prints both arrays on mismatch, whereas
     # assertTrue(np.array_equal(...)) only reports "False is not true".
     np.testing.assert_array_equal(
         permute.in_node(1).in_node().value,
         np.array([0, 3, 2, 1]),
     )