def test_tf_depth_to_space_infer(self):
    """Shape inference via ``depth_to_space_infer`` on the default graph fixture."""
    graph = build_graph(nodes, edges)
    depth_to_space = Node(graph, 'DtS')

    DepthToSpaceOp.depth_to_space_infer(depth_to_space)

    expected = np.array([1, 2048, 1152, 64])
    actual = graph.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, actual))
def replace_pattern(graph: Graph, match: dict):
    """Replace the matched ``reshape_0`` ... ``reshape_1`` pattern with a DepthToSpace node.

    The rewrite is performed only when the shapes seen on the two Reshape nodes
    are consistent with a depth-to-space rearrangement; otherwise the graph is
    left untouched.

    Args:
        graph: graph being transformed; the new DepthToSpace node is created in it.
        match: pattern match; must contain the nodes ``'reshape_0'`` (the
            channel-splitting Reshape) and ``'reshape_1'`` (the channel-concatenating
            Reshape).
    """
    split_reshape = match['reshape_0']
    concat_reshape = match['reshape_1']

    in_shape = split_reshape.in_port(0).data.get_shape()
    out_shape = concat_reshape.in_port(1).data.get_value()
    # The batch dimension must survive the whole pattern unchanged.
    if in_shape[0] != out_shape[0]:
        return

    split_shape = split_reshape.in_port(1).data.get_value()
    # Dims 0, 2 and 3 of the 4D input must reappear at positions 0, 4 and 5 of
    # the 6D split shape. A list (not a generator) is used so every pair is
    # evaluated, exactly as before.
    dims_kept = all([in_shape[src] == split_shape[dst] for src, dst in ((0, 0), (2, 4), (3, 5))])
    # NOTE(review): for the 'depth_first' layout [N, C/bs^2, bs, bs, H, W] the two
    # block-size dims are indices 2 and 3, yet this compares 1 and 2 (which fits a
    # [N, bs, bs, C', H, W] blocks-first layout). Either this check or the hard-coded
    # mode below may be wrong -- confirm against the pattern's transpose order.
    if not dims_kept or split_shape[1] != split_shape[2]:
        return

    block_size = split_shape[2]
    block_area = block_size * block_size
    expected = [
        in_shape[0],
        in_shape[1] // block_area,
        in_shape[2] * block_size,
        in_shape[3] * block_size,
    ]
    # Only fuse when the final Reshape produces exactly the DepthToSpace result shape.
    if not np.array_equal(expected, out_shape):
        return

    name = concat_reshape.soft_get('name', concat_reshape.id)
    attrs = {'name': name, 'block_size': block_size, 'mode': 'depth_first'}
    depth_to_space = DepthToSpaceOp(graph, attrs).create_node()

    # Re-route consumers of the last Reshape to DepthToSpace, then feed it from
    # the producer of the first Reshape's input.
    concat_reshape.out_port(0).get_connection().set_source(depth_to_space.out_port(0))
    depth_to_space.in_port(0).connect(split_reshape.in_port(0).get_source())
def extract(cls, node):
    """Copy TensorFlow DepthToSpace attributes from ``node.pb`` onto the node.

    Reads the integer ``block_size`` attribute and the ``data_format`` string
    attribute (decoded as UTF-8) and stores both via
    ``DepthToSpaceOp.update_node_stat``.

    Returns:
        ``cls.enabled``.
    """
    tf_attrs = node.pb.attr
    DepthToSpaceOp.update_node_stat(node, {
        'block_size': tf_attrs['block_size'].i,
        'data_format': tf_attrs['data_format'].s.decode('utf-8'),
    })
    return cls.enabled
def test_tf_depth_to_space_infer_nchw(self):
    """Shape inference via ``DepthToSpaceOp.infer`` with an NCHW graph layout."""
    graph = build_graph(nodes, edges)
    graph.graph['layout'] = 'NCHW'
    graph.node['in_data_node']['shape'] = np.array([1, 256, 1024, 576])
    depth_to_space = Node(graph, 'DtS')

    DepthToSpaceOp.infer(depth_to_space)

    expected = np.array([1, 64, 2048, 1152])
    actual = graph.node['out_data_node']['shape']
    self.assertTrue(np.array_equal(expected, actual))
def extract(cls, node):
    """Copy ONNX DepthToSpace attributes onto the node.

    ``blocksize`` is mandatory. ``mode`` defaults to ``'DCR'``. ONNX ``'DCR'`` is
    stored as ``'blocks_first'`` and ``'CRD'`` as ``'depth_first'``.

    Returns:
        ``cls.enabled``.

    Raises:
        AssertionError: if ``blocksize`` is missing or ``mode`` is neither
            ``'DCR'`` nor ``'CRD'`` (only when assertions are enabled).
    """
    node_name = node.soft_get('name', node.id)

    block_size = onnx_attr(node, 'blocksize', 'i', default=None)
    assert block_size is not None, \
        'DepthToSpace should have "blocksize" attribute specified for node {}'.format(node_name)

    onnx_mode = onnx_attr(node, 'mode', 's', default=b'DCR').decode()
    assert onnx_mode in ('DCR', 'CRD'), \
        'Unrecognized mode provided for DepthToSpace node {}'.format(node_name)

    # Anything other than DCR falls through to depth_first (relevant only if
    # assertions are disabled).
    mode = 'blocks_first' if onnx_mode == 'DCR' else 'depth_first'
    DepthToSpaceOp.update_node_stat(node, {'block_size': block_size, 'mode': mode})
    return cls.enabled
def test_tf_depth_to_space_infer_error_1(self):
    """``depth_to_space_infer`` must raise ``Error`` for an input shape it rejects.

    The fixture input is replaced with shape [1, 1024, 576, 255].
    """
    graph = build_graph(nodes, edges)
    graph.node['in_data_node']['shape'] = np.array([1, 1024, 576, 255])
    dts_node = Node(graph, 'DtS')
    # Use the context-manager form. The previous code called the function
    # eagerly and passed its *result* to assertRaises. That meant a raised
    # Error escaped the assertion (test error), and no error at all was never
    # detected.
    with self.assertRaises(Error):
        DepthToSpaceOp.depth_to_space_infer(dts_node)