def test_tf_space_to_depth_infer_nhwc(self):
     """Check the output shape SpaceToDepth.infer produces on an NHWC graph.

     The input shape comes from the shared `nodes` fixture, which is defined
     outside this view.
     """
     test_graph = build_graph(nodes, edges)
     test_graph.graph['layout'] = 'NHWC'
     SpaceToDepth.infer(Node(test_graph, 'StD'))
     expected = np.array([1, 1024, 576, 256])
     inferred = test_graph.node['out_data_node']['shape']
     self.assertTrue(np.array_equal(expected, inferred))
 def test_tf_space_to_depth_infer_nchw(self):
     """Check the output shape SpaceToDepth.infer produces on an NCHW graph."""
     test_graph = build_graph(nodes, edges)
     test_graph.graph['layout'] = 'NCHW'
     # Replace the fixture's input shape with an NCHW-ordered one before inference.
     test_graph.node['in_data_node']['shape'] = np.array([1, 64, 2048, 1152])
     space_to_depth = Node(test_graph, 'StD')
     SpaceToDepth.infer(space_to_depth)
     expected = np.array([1, 256, 1024, 576])
     self.assertTrue(
         np.array_equal(expected, test_graph.node['out_data_node']['shape'])
     )
 def extract(cls, node):
     """Copy SpaceToDepth attributes from the TF protobuf onto the node.

     Reads the integer `block_size` attribute and the `data_format` attribute.
     The latter is decoded from bytes as UTF-8. Both are stored on the node via
     SpaceToDepth.update_node_stat.

     Returns:
         cls.enabled
     """
     proto_attrs = node.pb.attr
     SpaceToDepth.update_node_stat(node, {
         'block_size': proto_attrs['block_size'].i,
         'data_format': proto_attrs['data_format'].s.decode('utf-8'),
     })
     return cls.enabled
Example #4
0
 def test_tf_space_to_depth_infer_nchw_dynamic(self):
     """Check that SpaceToDepth.infer keeps a dynamic dimension dynamic (NCHW)."""
     test_graph = build_graph(nodes, edges)
     test_graph.graph['layout'] = 'NCHW'
     # The third input dimension is dynamic; the expected output keeps it dynamic.
     input_dims = [1, 64, dynamic_dimension_value, 1152]
     test_graph.node['in_data_node']['shape'] = shape_array(input_dims)
     SpaceToDepth.infer(Node(test_graph, 'StD'))
     expected = shape_array([1, 256, dynamic_dimension_value, 576])
     actual = test_graph.node['out_data_node']['shape']
     self.assertTrue(strict_compare_tensors(expected, actual))
Example #5
0
 def extract(cls, node):
     """Copy the ONNX `blocksize` attribute onto the node as `block_size`.

     The attribute is read as an integer ('i') via onnx_attr with default=None.
     Presumably that default is what gets stored when the attribute is
     missing; confirm against onnx_attr.

     Returns:
         cls.enabled
     """
     SpaceToDepth.update_node_stat(
         node,
         {'block_size': onnx_attr(node, 'blocksize', 'i', default=None)},
     )
     return cls.enabled