def test_tf_space_to_depth_infer_nhwc(self):
    """SpaceToDepth.infer on an NHWC-layout graph yields the expected output shape.

    The input data shape and block size come from the shared ``nodes`` /
    ``edges`` fixture (not visible here); this test does not override them.
    """
    g = build_graph(nodes, edges)
    g.graph['layout'] = 'NHWC'

    SpaceToDepth.infer(Node(g, 'StD'))

    expected = np.array([1, 1024, 576, 256])
    actual = g.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, actual))
def extract(cls, node):
    """Copy TensorFlow SpaceToDepth attributes onto the node.

    Reads ``block_size`` and ``data_format`` from the node's protobuf
    definition (``node.pb.attr``) and stores them with
    ``SpaceToDepth.update_node_stat``.

    :param node: node whose ``pb`` holds the TensorFlow NodeDef.
    :return: ``cls.enabled``.
    """
    attrs = node.pb.attr
    block_size = attrs['block_size'].i
    # TensorFlow declares SpaceToDepth's data_format as optional with default
    # "NHWC". A GraphDef saved without default-valued attrs leaves this map
    # entry empty, so ``.s`` decodes to ''. Fall back to the TF default instead
    # of storing an invalid empty layout string.
    data_format = attrs['data_format'].s.decode('utf-8') or 'NHWC'
    SpaceToDepth.update_node_stat(node, {
        'block_size': block_size,
        'data_format': data_format,
    })
    return cls.enabled
def test_tf_space_to_depth_infer_nchw(self):
    """SpaceToDepth.infer on an NCHW-layout graph with an explicit static input shape."""
    g = build_graph(nodes, edges)
    g.graph['layout'] = 'NCHW'
    in_attrs = g.node['in_data_node']
    in_attrs['shape'] = np.array([1, 64, 2048, 1152])

    SpaceToDepth.infer(Node(g, 'StD'))

    expected = np.array([1, 256, 1024, 576])
    actual = g.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, actual))
def test_tf_space_to_depth_infer_nchw_dynamic(self):
    """SpaceToDepth.infer keeps a dynamic spatial dimension dynamic (NCHW layout)."""
    g = build_graph(nodes, edges)
    g.graph['layout'] = 'NCHW'
    in_attrs = g.node['in_data_node']
    in_attrs['shape'] = shape_array([1, 64, dynamic_dimension_value, 1152])

    SpaceToDepth.infer(Node(g, 'StD'))

    expected = shape_array([1, 256, dynamic_dimension_value, 576])
    actual = g.node['out_data_node']['shape']
    self.assertTrue(strict_compare_tensors(expected, actual))
def extract(cls, node):
    """Copy the ONNX SpaceToDepth ``blocksize`` attribute onto the node as ``block_size``.

    The integer attribute is read via ``onnx_attr`` with ``default=None``, so
    presumably ``None`` is stored when the model omits it (confirm against
    ``onnx_attr``).

    :return: ``cls.enabled``.
    """
    SpaceToDepth.update_node_stat(
        node,
        dict(block_size=onnx_attr(node, 'blocksize', 'i', default=None)),
    )
    return cls.enabled