Exemplo n.º 1
0
    def replace_pattern(graph: Graph, match: dict):
        """
        Fuse a matched channel-splitting Reshape ... Reshape chain into one DepthToSpace node.

        ``match['reshape_0']`` must reshape a 4D input [N, C, H, W] into the 6D shape
        [N, b, b, C / (b * b), H, W], and ``match['reshape_1']`` must produce
        [N, C / (b * b), H * b, W * b]. When every check passes, the consumers of
        ``reshape_1`` are re-wired to a new DepthToSpace node (named after ``reshape_1``)
        that reads the original input of ``reshape_0``. Otherwise the graph is left untouched.

        :param graph: graph being transformed
        :param match: matched pattern nodes; the 'reshape_0' and 'reshape_1' entries are used
        """
        channel_splitting_reshape = match['reshape_0']
        channel_concating_reshape = match['reshape_1']

        initial_shape = channel_splitting_reshape.in_port(0).data.get_shape()
        channel_splitted_out_shape = channel_splitting_reshape.in_port(1).data.get_value()
        resulting_shape = channel_concating_reshape.in_port(1).data.get_value()

        # Reshape targets may be non-constant (get_value() returns None); the rank checks
        # also keep the indexing below from raising on shapes this pass does not handle.
        if initial_shape is None or channel_splitted_out_shape is None or resulting_shape is None:
            return
        if len(initial_shape) != 4 or len(channel_splitted_out_shape) != 6 or len(resulting_shape) != 4:
            return

        # [N, C, H, W] -> [N, b, b, C / (b * b), H, W]: batch and spatial dims pass through
        # unchanged and both block dimensions must be equal.
        if any(initial_shape[src] != channel_splitted_out_shape[dst] for src, dst in ((0, 0), (2, 4), (3, 5))) or \
                channel_splitted_out_shape[1] != channel_splitted_out_shape[2]:
            return
        block_size = channel_splitted_out_shape[2]
        # Guard against a zero/negative block (division below) and a channel count that
        # is not a whole number of b * b blocks.
        if block_size < 1 or initial_shape[1] % (block_size * block_size) != 0:
            return

        expected_output_shape = [
            initial_shape[0], initial_shape[1] // (block_size * block_size),
            initial_shape[2] * block_size, initial_shape[3] * block_size
        ]
        if not np.array_equal(expected_output_shape, resulting_shape):
            return

        name = channel_concating_reshape.soft_get('name', channel_concating_reshape.id)
        # The [N, b, b, C / (b * b), H, W] split places the block dims before depth, which is
        # the 'blocks_first' (ONNX DCR) layout; 'depth_first' would need
        # [N, C / (b * b), b, b, H, W] and differs whenever C / (b * b) > 1.
        # NOTE(review): assumes the matched pattern's middle Transpose uses order
        # [0, 3, 4, 1, 5, 2] (not visible here) — confirm.
        depth_to_space = DepthToSpaceOp(graph, {
            'name': name,
            'block_size': block_size,
            'mode': 'blocks_first'
        }).create_node()
        channel_concating_reshape.out_port(0).get_connection().set_source(
            depth_to_space.out_port(0))
        depth_to_space.in_port(0).connect(
            channel_splitting_reshape.in_port(0).get_source())
Exemplo n.º 2
0
 def test_tf_depth_to_space_infer_nhwc(self):
     """DepthToSpaceOp.infer on an NHWC-layout graph writes the expected output data shape."""
     graph = build_graph(nodes, edges)
     graph.graph['layout'] = 'NHWC'
     dts_node = Node(graph, 'DtS')
     DepthToSpaceOp.infer(dts_node)
     exp_shape = np.array([1, 2048, 1152, 64])
     res_shape = graph.node['out_data_node']['shape']
     # assert_array_equal reports both shapes on failure, unlike assertTrue(np.array_equal(...)).
     np.testing.assert_array_equal(res_shape, exp_shape)
Exemplo n.º 3
0
 def extract(cls, node):
     """
     Copy the ``block_size`` and ``data_format`` attributes from ``node.pb`` onto ``node``.

     :param node: node whose ``pb.attr`` map holds the source attributes
                  (presumably a TensorFlow NodeDef — confirm against the caller)
     :return: ``cls.enabled``
     """
     tf_attrs = node.pb.attr
     attrs = dict(
         block_size=tf_attrs['block_size'].i,
         data_format=tf_attrs['data_format'].s.decode('utf-8'),
     )
     DepthToSpaceOp.update_node_stat(node, attrs)
     return cls.enabled
Exemplo n.º 4
0
    def extract(cls, node):
        """
        Read the ONNX DepthToSpace ``blocksize`` and ``mode`` attributes and store them on ``node``.

        ``mode`` defaults to ``DCR`` and is translated as DCR -> 'blocks_first',
        CRD -> 'depth_first'.

        :param node: ONNX DepthToSpace node to annotate
        :return: ``cls.enabled``
        :raises AssertionError: (when assertions are enabled) if ``blocksize`` is missing
                                or ``mode`` is neither DCR nor CRD
        """
        node_name = node.soft_get('name', node.id)

        block_size = onnx_attr(node, 'blocksize', 'i', default=None)
        assert block_size is not None, \
            'DepthToSpace should have "blocksize" attribute specified for node {}'.format(node_name)

        onnx_mode = onnx_attr(node, 'mode', 's', default=b'DCR').decode()
        assert onnx_mode in ('DCR', 'CRD'), 'Unrecognized mode provided for DepthToSpace node {}'.format(node_name)

        # DCR keeps blocks before depth; anything else (CRD) is depth first.
        mode = 'blocks_first' if onnx_mode == 'DCR' else 'depth_first'
        DepthToSpaceOp.update_node_stat(node, {'block_size': block_size, 'mode': mode})
        return cls.enabled