def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched node's implicit squeeze with an explicit Squeeze op.

    A Squeeze node is inserted after output port 0 of ``match['node']``.
    When exactly two input ports of the node are connected, the squeeze axes
    are taken from the producer of input port 1. Otherwise they are taken
    from the node's ``axis`` attribute, materialized as a new Const node.

    :param graph: graph being transformed.
    :param match: pattern match; only ``match['node']`` is used.
    :return: always an empty list.
    """
    node = match['node']
    # Fall back to the node id when no 'name' attribute is set. This matches
    # the naming convention of the other squeeze-inserting pass.
    node_name = node.soft_get('name', node.id)

    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]

    squeeze_node = Squeeze(graph, dict()).create_node([], dict(name=node_name + '/Squeeze'))
    if len(connected_ports) == 2:
        # Axes are supplied by the node's second input: share that producer.
        node.in_port(1).get_source().connect(squeeze_node.in_port(1))
    else:
        # Axes are a static attribute: feed them through a Const input.
        axis_node = Const(graph, {'value': node.axis}).create_node()
        squeeze_node.in_port(1).connect(axis_node.out_port(0))

    # Re-point the node's existing output connection so its consumers read
    # from the Squeeze, then feed the node's output into the Squeeze data input.
    node.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
    node.out_port(0).connect(squeeze_node.in_port(0))
    return []
def add_squeeze_for_shrink(graph: Graph, ss_node: Node) -> None:
    """Make a StridedSlice's shrink_axis_mask explicit via a following Squeeze.

    The StridedSlice output shape is rewritten to keep a size-1 dimension at
    every position flagged in ``shrink_axis_mask``. Those flags are cleared
    to 0 in place. A Squeeze node is then inserted after the StridedSlice
    output. It removes the flagged axes, listed in a Const ``/Indices``
    input, and restores the original output shape.

    The node is left untouched unless it has exactly 4 input nodes and
    exactly 1 output node.

    :param graph: graph being transformed; ``graph.graph['layout']`` is read
        to decide the Squeeze layout flags.
    :param ss_node: the StridedSlice node; its ``shrink_axis_mask`` attribute
        is modified in place.
    """
    # add Squeeze for shrink_axis_mask
    # NOTE(review): the message says "shrink mask '{}'" but formats the node id.
    log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    # Indices of the shrunk axes. This must be computed BEFORE the loop below
    # clears the mask in place. mo_array is presumably a numpy-array wrapper
    # (TODO confirm), so this is boolean-mask indexing over range(len(mask)).
    dim = mo_array(range(len(ss_node['shrink_axis_mask'])))[mo_array(ss_node['shrink_axis_mask'], dtype=bool)]
    ss_shape = []
    i = 0  # position in shrink_axis_mask (and in the rebuilt shape)
    k = 0  # position in shape_out (the shrunk output shape)

    # Don't permute reshape if channels were squeezed
    dont_permute = graph.graph['layout'] == 'NCHW'
    # Reads the last mask flag before the loop below clears it.
    if graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][-1] == 1:
        dont_permute = True

    # Rebuild the un-shrunk shape: copy a dim from shape_out for each
    # non-shrunk position, or insert 1 (and clear the flag) for each shrunk one.
    while k < len(shape_out):
        if i >= len(ss_node['shrink_axis_mask']) or not ss_node['shrink_axis_mask'][i]:
            ss_shape.append(shape_out[k])
            k = k + 1
        else:
            ss_node['shrink_axis_mask'][i] = 0
            ss_shape.append(1)
        i = i + 1

    # Any mask positions left after shape_out is exhausted become trailing
    # size-1 dims. Their flags are cleared whatever their original value was.
    while i < len(ss_node['shrink_axis_mask']):
        ss_node['shrink_axis_mask'][i] = 0
        ss_shape.append(1)
        i = i + 1

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Squeeze
    squeeze_node = Squeeze(graph, dict(name=ss_node.name + '/Squeeze_shrink',
                                       nchw_layout=dont_permute,
                                       correct_data_layout=dont_permute)).create_node()
    ss_node.out_port(0).get_connection().insert_node(squeeze_node)
    # The Squeeze output carries the original (shrunk) StridedSlice shape.
    squeeze_node.out_port(0).data.set_shape(shape_out)

    # Axes to squeeze are fed as a constant second input.
    dims_node = Const(graph, {'name': squeeze_node.id + '/Indices', 'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(squeeze_node.in_port(1))
def find_and_replace_pattern(self, graph: Graph):
    """Insert an explicit Squeeze after each output of ``squeeze_axis`` nodes.

    For every op node with ``squeeze_axis=True``, a Squeeze named
    ``<node name>/Squeeze_`` is inserted on each output connection. The axes
    come from the node's ``axis`` attribute when it is valid, wrapped as a
    constant input. Otherwise the producer of the node's input port 1 is
    shared with the Squeeze's port 1.

    :param graph: graph being transformed.
    :raises Error: when a node has neither a valid ``axis`` nor a connected
        input port 1. This is detected while iterating its output ports.
    """
    for op in graph.get_op_nodes(squeeze_axis=True):
        op_name = op.soft_get('name', op.id)
        squeeze_name = op_name + '/Squeeze_'

        for port in op.out_ports().values():
            if op.has_valid('axis'):
                # Static axes: materialize them as a Const input on port 1.
                squeeze = create_op_with_const_inputs(graph, Squeeze, {1: mo_array(op.axis)},
                                                      {'name': squeeze_name})
                port.get_connection().insert_node(squeeze)
                continue

            if not op.is_in_port_connected(1):
                raise Error('Unknown axis to squeeze for node {}'.format(op_name))

            # Dynamic axes: reuse whatever already feeds the node's port 1.
            squeeze = Squeeze(graph, {'name': squeeze_name}).create_node()
            port.get_connection().insert_node(squeeze)
            op.in_port(1).get_connection().add_destination(squeeze.in_port(1))