Example #1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Attach a Squeeze node after output 0 of the matched node.

        The Squeeze axes input (port 1) is taken from the producer of the
        matched node's input port 1 when exactly two of the node's input ports
        are connected. Otherwise a new Const built from ``node.axis`` supplies
        the axes.

        The connection leaving the node's output 0 is re-sourced to the Squeeze
        output (``set_source``), and the node's output 0 is then connected to
        the Squeeze data input.

        :param graph: graph in which the new nodes are created.
        :param match: pattern match; only ``match['node']`` is used.
        :return: an empty list.
        """
        matched = match['node']
        num_connected = sum(1 for p in matched.in_ports().values() if not p.disconnected())

        # The Squeeze must be created before any Const, as in the original flow.
        squeeze = Squeeze(graph, {}).create_node([], {'name': matched.name + '/Squeeze'})
        axes_port = squeeze.in_port(1)

        if num_connected != 2:
            # No explicit axes input: materialize the ``axis`` attribute as a constant.
            axes_const = Const(graph, {'value': matched.axis}).create_node()
            axes_port.connect(axes_const.out_port(0))
        else:
            # Share the node's own axes producer with the Squeeze.
            matched.in_port(1).get_source().connect(axes_port)

        node_out = matched.out_port(0)
        node_out.get_connection().set_source(squeeze.out_port(0))
        node_out.connect(squeeze.in_port(0))
        return []
    def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
        """Insert a Squeeze after a StridedSlice to implement its shrink_axis_mask.

        Steps:

        1. The StridedSlice output shape is rewritten so that every axis flagged
           in ``shrink_axis_mask`` is kept with size 1.
        2. The mask is cleared in place.
        3. A Squeeze is inserted on the output connection. Its input port 1 is
           fed by a Const holding the flagged axis indices, and its output shape
           is set to the original StridedSlice output shape.

        :param graph: graph owning ``ss_node``; ``graph.graph['layout']`` is read.
        :param ss_node: StridedSlice node. Its ``shrink_axis_mask`` entries are
            all set to 0 in place.
        :return: None. Nothing is changed unless the node has exactly 4 input
            nodes and 1 output node.
        """
        # add Squeeze for shrink_axis_mask
        # soft_get keeps logging safe even if the attribute is missing.
        log.info("StridedSlice op '{}' with shrink mask '{}' has been detected".format(
            ss_node.id, ss_node.soft_get('shrink_axis_mask')))

        if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
            return

        # Same object as the node attribute, so item assignments below mutate it.
        shrink_mask = ss_node['shrink_axis_mask']
        shape_out = ss_node.out_node().shape
        # Indices of the shrunk axes; must be computed before the mask is cleared.
        dim = mo_array(range(len(shrink_mask)))[mo_array(shrink_mask, dtype=bool)]
        ss_shape = []
        i = 0  # position in shrink_mask
        k = 0  # position in shape_out

        # Don't permute reshape if channels were squeezed
        # (NOTE(review): "channels" = last axis for NHWC, per the original comment).
        dont_permute = graph.graph['layout'] == 'NCHW'
        if graph.graph['layout'] == 'NHWC' and len(shrink_mask) > 0 and shrink_mask[-1] == 1:
            dont_permute = True

        while k < len(shape_out):
            if i >= len(shrink_mask) or not shrink_mask[i]:
                # Axis survives the slice: copy its output size.
                ss_shape.append(shape_out[k])
                k += 1
            else:
                # Shrunk axis: keep it with size 1 and clear its mask bit.
                shrink_mask[i] = 0
                ss_shape.append(1)
            i += 1

        # Any mask entries left after all output axes are consumed become size-1 axes.
        while i < len(shrink_mask):
            shrink_mask[i] = 0
            ss_shape.append(1)
            i += 1

        ss_node.out_port(0).data.set_shape(ss_shape)

        # insert Squeeze
        squeeze_node = Squeeze(
            graph,
            dict(name=ss_node.name + '/Squeeze_shrink',
                 nchw_layout=dont_permute,
                 correct_data_layout=dont_permute)).create_node()
        ss_node.out_port(0).get_connection().insert_node(squeeze_node)
        # The Squeeze restores the original (shrunk) output shape for consumers.
        squeeze_node.out_port(0).data.set_shape(shape_out)

        dims_node = Const(graph, {
            'name': squeeze_node.id + '/Indices',
            'value': int64_array(dim)
        }).create_node()
        dims_node.out_port(0).connect(squeeze_node.in_port(1))
Example #3
0
 def find_and_replace_pattern(self, graph: Graph):
     """Insert a Squeeze after every output port of nodes with ``squeeze_axis=True``.

     The squeeze axes come from the node's ``axis`` attribute when it is valid.
     Otherwise they come from whatever feeds the node's input port 1.

     :raises Error: when a node has neither a valid ``axis`` nor a connected
         input port 1. This is raised only if the node has at least one
         output port.
     """
     for node in graph.get_op_nodes(squeeze_axis=True):
         node_name = node.soft_get('name', node.id)
         for port in node.out_ports().values():
             if node.has_valid('axis'):
                 # Axes known from the attribute: supply them as a constant input.
                 squeeze = create_op_with_const_inputs(
                     graph, Squeeze, {1: mo_array(node.axis)},
                     {'name': node_name + '/Squeeze_'})
                 port.get_connection().insert_node(squeeze)
                 continue

             if not node.is_in_port_connected(1):
                 raise Error('Unknown axis to squeeze for node {}'.format(node_name))

             # Axes provided by another node: share the producer of input 1 with the Squeeze.
             squeeze = Squeeze(graph, {'name': node_name + '/Squeeze_'}).create_node()
             port.get_connection().insert_node(squeeze)
             node.in_port(1).get_connection().add_destination(squeeze.in_port(1))