def order(op_node: Node, port_info: str, input_port: int):
    """
    Performs layout change related transformation of the data on the in_port_idx port of op_node.
    Translates ordered shape indexes from one layout to another according to permutation

    Transformation inserts two Gather operations (both gather along axis 0, which is passed
    explicitly as the 2-port input)

    1 Gather reorders data to new layout according to direct permutation:
        actual data to translate as 0-port input data and permutation as 1-port input indexes of Gather
    2 Gather translates shape indexes from one layout to another according to inverse permutation
        inverse permutation as 0-port input data and the output of the first Gather as 1-port input
        indexes of Gather

    For example:
        NHWC Transpose operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with new order indices [0, 1, 3, 2].
        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]
            1 phase (after first Gather insertion):
                1-port input order indices = [0, 2, 1, 3]
            2 phase (after second Gather insertion):
                1-port input order indices = [0, 3, 2, 1]

    :param op_node: node whose input `input_port` is rerouted through the two inserted Gathers
    :param port_info: port description used to locate the data node carrying the `permutation` attribute
    :param input_port: index of the op_node input that carries the order indices to translate
    """
    import numpy as np  # local import: only needed to build the int64 axis constants

    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                             'port_info "{}".'.format(permutation_data_node.id,
                                                                                      op_node.id, port_info)
    permutation = permutation_data_node.permutation
    if len(permutation.perm) == 0:
        return

    data_node = op_node.in_node(input_port)

    # 1st Gather: data[perm] -- reorders the order values themselves to the new layout
    const = Const(graph, {
        'value': permutation.perm,
        'need_shape_inference': True
    }).create_node_with_data()
    # Gather takes the axis as an explicit 2-port input (as in the sibling `shape`/`axis` helpers)
    axis_const = Const(graph, {
        'value': np.array(0, dtype=np.int64),
        'name': op_node.name + '/OrderGather_1/axis'
    }).create_node_with_data()
    gather = Gather(graph, {
        'name': op_node.name + '/OrderGather_1',
        'need_shape_inference': True
    }).create_node_with_data([data_node, const, axis_const])

    # 2nd Gather: inv[reordered] -- translates each index value to its position in the new layout
    const_1 = Const(graph, {
        'value': permutation.inv,
        'need_shape_inference': True
    }).create_node_with_data()
    axis_const_1 = Const(graph, {
        'value': np.array(0, dtype=np.int64),
        'name': op_node.name + '/OrderGather_2/axis'
    }).create_node_with_data()
    gather_1 = Gather(graph, {
        'name': op_node.name + '/OrderGather_2',
        'need_shape_inference': True
    }).create_node_with_data([const_1, gather, axis_const_1])

    # move the original edge (keeping its attributes) so that op_node reads the translated indices
    attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
    graph.add_edge(gather_1.id, op_node.id, **attrs)
    graph.remove_edge(data_node.id, op_node.id)
    op_node['need_shape_inference'] = True
def extract(node):
    """
    Populate Gather attributes for `node` from its `axis` attribute.

    The axis is read with `onnx_attr` as an integer ('i') defaulting to 0 and stored as an
    int64 numpy scalar.
    """
    axis_value = onnx_attr(node, 'axis', 'i', default=0)
    Gather.update_node_stat(node, {'axis': np.array(axis_value, dtype=np.int64)})
    return __class__.enabled
def replace_op(self, graph: Graph, node: Node):
    """
    Replace `node` with the chain Transpose([1, 0]) -> Gather(axis=0) -> Transpose([1, 0]).

    The gather indexes come from `node.parameters`: an int32 count read with
    `read_binary_integer32_token`, followed by that many int32 values, each decreased by 1
    (presumably converting 1-based indexes to 0-based -- confirm against the format spec).

    :return: single-element list with the id of the final Transpose node
    """
    params = node.parameters
    count = read_binary_integer32_token(params)
    indexes = read_blob(params, count, dtype=np.int32) - 1

    indexes_node = Const(graph).create_node(attrs={
        'name': 'indexes/{}'.format(node.id),
        'value': np.array(indexes),
        'shape': [count],
        'data_type': np.int32,
    })
    # the same [1, 0] order Const feeds both Transposes
    transpose_order = Const(graph, {
        'value': np.array([1, 0], dtype=np.int64),
        'shape': [2],
        'data_type': np.int64,
    }).create_node()
    gather_axis = Const(graph, {'value': int64_array(0)}).create_node()

    input_transpose = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)])
    input_transpose.in_port(0).connect(node.in_port(0).get_source())
    input_transpose.in_port(1).connect(transpose_order.out_port(0))

    gather = Gather(graph, {}).create_node()
    for port_idx, producer in enumerate((input_transpose, indexes_node, gather_axis)):
        gather.in_port(port_idx).connect(producer.out_port(0))

    output_transpose = Transpose(graph, {'name': 'output_permute'}).create_node()
    output_transpose.in_port(0).connect(gather.out_port(0))
    output_transpose.in_port(1).connect(transpose_order.out_port(0))

    return [output_transpose.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched node with a Gather that receives its `axis` attribute as a Const
    on input port 2; inputs 0 and 1 and the output are moved over unchanged.
    """
    op = match['op']
    op_name = op.soft_get('name', op.id)
    assert op.has_valid('axis')

    axis_node = Const(graph, {'name': op_name + '/axis', 'value': int64_array(op.axis)}).create_node()
    gather = Gather(graph, {'name': op_name}).create_node()

    for port_idx in (0, 1):
        op.in_port(port_idx).get_connection().set_destination(gather.in_port(port_idx))
    axis_node.out_port(0).connect(gather.in_port(2))
    op.out_port(0).get_connection().set_source(gather.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched node with a Gather whose axis is given as the attribute `axis=0`.

    The inputs are swapped: the node's input 0 becomes Gather indices (port 1) and its
    input 1 becomes Gather data (port 0).
    """
    op = match['op']
    gather_name = op.id + '/embedding_'
    gather = Gather(graph, dict(name=gather_name, axis=0, symbol_dict={'name': gather_name})).create_node()

    op.in_port(0).get_connection().set_destination(gather.in_port(1))
    op.in_port(1).get_connection().set_destination(gather.in_port(0))
    op.out_port(0).get_connection().set_source(gather.out_port(0))
def test_gather_infer(self):
    """Gather.infer on the fixture graph must produce a [2, 15] output filled with ones."""
    graph = self._create_graph()

    gather_node = Node(graph, 'gather_node')
    Gather.infer(gather_node)

    exp_shape = int64_array([2, 15])
    res_shape = graph.node['gather_output']['shape']
    res_value = graph.node['gather_output']['value']

    self.assertTrue(np.array_equal(exp_shape, res_shape),
                    'shapes do not match expected: {} and given: {}'.format(exp_shape, res_shape))

    # report the values (not the shapes) when the value comparison fails
    exp_value = np.ones(exp_shape)
    self.assertTrue(np.array_equal(res_value, exp_value),
                    'values do not match expected: {} and given: {}'.format(exp_value, res_value))
def shape(op_node: Node, port_info: str, input_port: int):
    """
    Reorder the values arriving on input `input_port` of `op_node` with the direct permutation.

    Inserts Gather(data, permutation.perm, axis=0) between the current data node and
    `op_node`. Nothing is inserted when the permutation is empty.
    """
    graph = op_node.graph
    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                   'port_info "{}".'.format(perm_holder.id, op_node.id, port_info)
    permutation = perm_holder.permutation
    if not len(permutation.perm):
        return

    source = op_node.in_node(input_port)
    base_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'

    indices = Const(graph, {'value': permutation.perm,
                            'name': base_name + '/const',
                            'need_shape_inference': True}).create_node_with_data()
    gather_axis = Const(graph, {'value': int64_array(0),
                                'name': base_name + '/axis'}).create_node_with_data()
    gathered = Gather(graph, {'name': base_name,
                              'need_shape_inference': True}).create_node_with_data([source, indices, gather_axis])

    # re-point the original edge (with its attributes) to come from the Gather output
    edge_attrs = graph.get_edge_data(source.id, op_node.id, key=0).copy()
    graph.add_edge(gathered.id, op_node.id, **edge_attrs)
    graph.remove_edge(source.id, op_node.id)

    # need to run manually to override output shape value to resolve shape collision for nodes with
    # 'correct_data_layout' output port attrs
    op_node['need_shape_inference'] = True
def replace_op(self, graph: Graph, node: Node):
    """
    Replace `node` with Gather(data=input 1 (weight), indices=input 0 (input_ids), axis=0).

    :return: single-element list with the new Gather node id
    """
    axis_node = Const(graph, {'value': 0}).create_node()
    gather = Gather(graph, dict(name=node.name)).create_node([
        node.in_node(1),  # weight
        node.in_node(0),  # input_ids
        axis_node,
    ])
    return [gather.id]
def get_shape_values_by_indices_node(shape_node: Node, indices_node: Node) -> Node:
    """
    Build a Gather(axis=0) that picks elements of `shape_node`'s output at the positions
    produced by `indices_node`.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices_node: the node of 1D output shape with the list of element indices to get
    :return: the Gather node producing the requested elements
    """
    graph = shape_node.graph
    prefix = shape_node.name

    axis_node = Const(graph, {'value': int64_array(0), 'name': prefix + '/Axis'}).create_node()
    gather_node = Gather(graph, {'name': prefix + '/Gather'}).create_node()

    # Gather ports: 0 = data, 1 = indices, 2 = axis
    for port_idx, producer in enumerate((shape_node, indices_node, axis_node)):
        producer.out_port(0).connect(gather_node.in_port(port_idx))
    return gather_node
def replace_pattern(graph: Graph, match: dict):
    """
    Permute the shape fed to input 2 of the matched node when layouts disagree.

    Acts only when input 2 exists, is connected, and `shape_input` is set, and when the node
    has a `layout` that does not start with 'NC' while the graph layout is 'NCHW'. In that
    case a Gather(shape, perm, axis=0) is inserted on input 2, where perm comes from
    PermuteAttrs.get_nhwc_to_nchw_permutation for the rank of input 0.
    """
    node = match['op']
    if not node.has_port('in', 2) or node.in_port(2).disconnected() or not node.has_and_set('shape_input'):
        return
    if not node.has_valid('layout') or node.layout.startswith('NC') or graph.graph['layout'] != 'NCHW':
        return

    rank = len(node.in_port(0).data.get_shape())
    permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)

    shape_source = node.in_node(2)
    gather_name = node.soft_get('name', node.id) + '/ShapeGather'

    perm_const = Const(graph, {'value': permutation.perm,
                               'name': gather_name + '/Const',
                               'need_shape_inference': True}).create_node_with_data()
    axis_const = Const(graph, {'value': int64_array(0),
                               'name': gather_name + '/Axis'}).create_node_with_data()
    gathered = Gather(graph, {'name': gather_name,
                              'need_shape_inference': True}).create_node_with_data([shape_source, perm_const,
                                                                                     axis_const])

    edge_attrs = graph.get_edge_data(shape_source.id, node.id, key=0).copy()
    graph.add_edge(gathered.id, node.id, **edge_attrs)
    graph.remove_edge(shape_source.id, node.id)
def reorder_inputs_for_shape_or_slice(op_node: Node, input_port: int, permute_indices_for_gather: list):
    """
    Insert Gather(data, permute_indices_for_gather, axis=0) in front of input `input_port`
    of `op_node`.

    Axis and slice permutations are almost identical; the difference is that for slice the
    permutation in general depends on slice_rank rather than on input_rank or output_rank,
    which is why the indices are supplied by the caller.
    """
    graph = op_node.graph
    source = op_node.in_node(input_port)
    base_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'

    indices = Const(graph, {'value': permute_indices_for_gather,
                            'name': base_name + '/const',
                            'need_shape_inference': True}).create_node_with_data()
    gather_axis = Const(graph, {'value': int64_array(0),
                                'name': base_name + '/axis'}).create_node_with_data()
    gathered = Gather(graph, {'name': base_name,
                              'need_shape_inference': True}).create_node_with_data([source, indices, gather_axis])

    # re-point the original edge (with its attributes) to come from the Gather output
    edge_attrs = graph.get_edge_data(source.id, op_node.id, key=0).copy()
    graph.add_edge(gathered.id, op_node.id, **edge_attrs)
    graph.remove_edge(source.id, op_node.id)

    # need to run manually to override output shape value to resolve shape collision for nodes with
    # 'correct_data_layout' output port attrs
    op_node['need_shape_inference'] = True
def placeholder_scales(self, placeholder: Node):
    """
    Helper function to get scales for prior boxes out of input image size:
        [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]

    Builds: ShapeOf(placeholder)[1:3] -> Cast(fp32) -> Pow(-1) -> Gather([1, 0], axis=0),
    then concatenates that 2-element result with itself along axis 0.
    Elements [1:3] of the shape are expected to be [H, W] (NHWC input -- confirm).

    :param placeholder: 4D node with an np.ndarray `shape` attribute
    :return: the Concat node producing the four scale values
    """
    graph = placeholder.graph
    ph_name = placeholder.soft_get('name', placeholder.id)
    ph_shape = placeholder.soft_get('shape', None)

    assert ph_shape is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, ph_name)
    assert isinstance(ph_shape, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, ph_name)
    assert ph_shape.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, ph_name, ph_shape)

    shape_of = Shape(graph, {'name': 'input_image_shape'}).create_node()
    shape_of.in_port(0).connect(placeholder.out_port(0))

    # shape[1:3]
    slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
    slice_end = Const(graph, {'value': int64_array([3])}).create_node()
    slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
    hw = StridedSlice(graph, {'name': ph_name + '/get_h_w',
                              'begin_mask': np.array([1]),
                              'end_mask': np.array([1]),
                              'new_axis_mask': np.array([0]),
                              'shrink_axis_mask': np.array([0]),
                              'ellipsis_mask': np.array([0])}).create_node()
    for port_idx, producer in enumerate((shape_of, slice_begin, slice_end, slice_stride)):
        hw.in_port(port_idx).connect(producer.out_port(0))

    # element-wise reciprocal
    minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
    inverse_hw = Pow(graph, {}).create_node()
    inverse_hw.in_port(0).connect(hw.out_port(0))
    inverse_hw.in_port(1).connect(minus_one.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    inverse_hw.in_port(0).get_connection().insert_node(to_fp32)

    # swap the two elements
    swap_order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
    inverse_wh = Gather(graph, {}).create_node()
    inverse_wh.in_port(0).connect(inverse_hw.out_port(0))
    inverse_wh.in_port(1).connect(swap_order.out_port(0))
    gather_axis.out_port(0).connect(inverse_wh.in_port(2))

    # duplicate the pair: [a, b] -> [a, b, a, b]
    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    for port_idx in (0, 1):
        priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
    for port_idx in (0, 1):
        priors_scale_node.in_port(port_idx).connect(inverse_wh.out_port(0))
    return priors_scale_node
def replace_pattern(self, graph: Graph, match: dict):
    """
    Rewrite a GatherNd as Reshape + Gather(axis=0) when `self.indices_check` allows it.

    Requires constant indices. The data is reshaped to [-1] + data_shape[indices_rank_last:],
    the indices are reduced to the single column `gather_idx` returned by the check and
    flattened, then a new Gather replaces the original node, which is removed.
    """
    gather_nd = match['GatherNd']
    data_shape = gather_nd.in_node(0).shape
    nd_indices = gather_nd.in_node(1).value
    if nd_indices is None:
        # We can't do such special pass without indices value
        return

    # 0. All needed checks that we can replace GatherNd by Gather
    column = self.indices_check(nd_indices, data_shape)
    if column is None:
        log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_nd.name))
        return

    # 1. Add Reshape and connect
    flat_shape = int64_array([-1] + list(data_shape[nd_indices.shape[-1]:]))
    reshape = Reshape(graph, {'name': gather_nd.name + '/Reshape_for_GatherNd/'}).create_node()
    flat_shape_const = Const(graph, {'name': reshape.name + '/Dim', 'value': flat_shape}).create_node()
    gather_nd.in_port(0).get_connection().set_destination(reshape.in_port(0))
    reshape.in_port(1).connect(flat_shape_const.out_port(0))

    # 2. Change indices from Nd to 1d
    flat_indices = np.take(nd_indices, indices=[column], axis=-1).reshape([-1])
    flat_indices_const = Const(graph, dict(value=flat_indices)).create_node()

    # 3. Create new Gather operation and reconnect all inputs/outputs
    replacement = Gather(graph, {'name': gather_nd.name + '/NewGather/', 'axis': 0}).create_node()
    reshape.out_port(0).connect(replacement.in_port(0))
    flat_indices_const.out_port(0).connect(replacement.in_port(1))
    gather_nd.out_port(0).get_connection().set_source(replacement.out_port(0))

    # 4. Remove old Gather node
    graph.remove_node(gather_nd.id)
def node_to_get_shape_value_of_range(shape_node: Node, indices: list):
    """
    The function returns a node that produces values of the specified indices of the input node 'shape_node'

    The Gather is wired exactly like in `get_shape_values_by_indices_node`: data on port 0,
    indices on port 1 and an explicit axis-0 Const on port 2.

    :param shape_node: the node of 1D output shape to get elements from
    :param indices: the list of element indices to get
    :return: node producing required elements of the node
    """
    graph = shape_node.graph
    indices_node = Const(graph, {
        'value': int64_array(indices),
        'name': shape_node.name + '/Indices'
    }).create_node()
    # Gather expects the axis as its 2-port input; shape values are 1D, so gather along axis 0
    axis_node = Const(graph, {
        'value': int64_array(0),
        'name': shape_node.name + '/Axis'
    }).create_node()
    gather_node = Gather(graph, {
        'name': shape_node.name + '/Gather'
    }).create_node()

    shape_node.out_port(0).connect(gather_node.in_port(0))
    indices_node.out_port(0).connect(gather_node.in_port(1))
    axis_node.out_port(0).connect(gather_node.in_port(2))
    return gather_node
def replace_with_gather(node):
    """
    Replace `node` with a Gather that uses the node's `order` attribute as indices and its
    `axis` attribute as the gather axis; the node's output and input 0 are moved to the Gather.
    """
    graph = node.graph
    node_name = node.soft_get('name', node.id)
    gather_axis, index_values = node.axis, node.order

    indices_const = Const(graph, {'name': node_name + '/reverse_order', 'value': index_values}).create_node()
    axis_const = Const(graph, {'value': int64_array(gather_axis)}).create_node()
    gather = Gather(graph, {'name': node_name}).create_node()

    gather.in_port(1).connect(indices_const.out_port(0))
    gather.in_port(2).connect(axis_const.out_port(0))

    # consumers first, then the data input
    node.out_port(0).get_connection().set_source(gather.out_port(0))
    node.in_port(0).get_connection().set_destination(gather.in_port(0))
def axis(op_node: Node, port_info: str, input_port: int):
    """
    Translate axis indices on input `input_port` of `op_node` to the new layout.

    Inserts Gather(permutation.inv, axes, axis=0): the inverse permutation is the 0-port data
    and the original axis values are the 1-port indices, so every axis index is mapped to its
    position in the new layout. Nothing is inserted when the permutation is empty.

    For example:
        NHWC Reduce operation has 0-port input with data of shape [1, 2, 3, 4] and
        1-port input with axis indices [0, 1].
        After translating such operation to NCHW layout:
            0-port input shape = [1, 4, 2, 3]
            1-port input axis indices = [0, 2]
    """
    graph = op_node.graph
    perm_holder = get_node_with_permutation(op_node, port_info)
    assert perm_holder.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                   'port_info "{}".'.format(perm_holder.id, op_node.id, port_info)
    permutation = perm_holder.permutation
    if not len(permutation.perm):
        return

    axes_source = op_node.in_node(input_port)
    base_name = op_node.soft_get('name', op_node.id) + '/AxisGather'

    inverse_perm = Const(graph, {'value': permutation.inv,
                                 'name': base_name + '/const',
                                 'need_shape_inference': True}).create_node_with_data()
    gather_axis = Const(graph, {'value': int64_array(0),
                                'name': base_name + '/axis'}).create_node_with_data()
    translated = Gather(graph, {'name': base_name,
                                'need_shape_inference': True}).create_node_with_data([inverse_perm, axes_source,
                                                                                      gather_axis])

    edge_attrs = graph.get_edge_data(axes_source.id, op_node.id, key=0).copy()
    graph.add_edge(translated.id, op_node.id, **edge_attrs)
    graph.remove_edge(axes_source.id, op_node.id)
    op_node['need_shape_inference'] = True
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched node with a Gather whose axis (0) is supplied as a Const on port 2.

    The inputs are swapped: the node's input 0 becomes Gather indices (port 1) and its
    input 1 becomes Gather data (port 0).
    """
    op = match['op']
    gather_name = op.id + '/embedding_'
    gather = Gather(graph, dict(name=gather_name, symbol_dict={'name': gather_name})).create_node()
    axis_node = Const(graph, {'value': int64_array(0)}).create_node()

    op.in_port(0).get_connection().set_destination(gather.in_port(1))
    op.in_port(1).get_connection().set_destination(gather.in_port(0))
    axis_node.out_port(0).connect(gather.in_port(2))
    op.out_port(0).get_connection().set_source(gather.out_port(0))
def extract(cls, node):
    """Copy the integer `batch_dims` entry of `node.pb.attr` into the node's Gather attributes."""
    batch_dims = node.pb.attr['batch_dims'].i
    Gather.update_node_stat(node, {'batch_dims': batch_dims})
    return cls.enabled
def find_and_replace_pattern(self, graph: Graph):
    """
    Keep constant shape sub-graphs consistent after the layout change.

    1. After every ShapeOf whose input is not already in the correct layout and whose output
       has 4 or 5 elements, insert a Gather that reorders the shape values from N*C order
       (NCHW / NCDHW) to N...C order (NHWC / NDHWC).
    2. For every node found by `search_of_constant_path_end` starting from those ShapeOf ops,
       insert a Gather on each input fed from the constant shape paths that reorders 4/5-element
       shapes back from N...C order to N*C order (the inverse of step 1).

    Every inserted Gather gathers along axis 0, passed explicitly as its 2-port input.
    """
    # 1. Inserting Gather to N*C format on constant shape paths
    #   - Search for Shape ops
    #   - Inserting Gather after them in case of [4] or [5] output shape
    shape_ops = graph.get_op_nodes(op='ShapeOf')
    constant_shape_paths = set()

    for shape in shape_ops:
        output_port = shape.in_port(0).get_source()
        if is_output_data_in_correct_layout(output_port.node, output_port.idx):
            continue
        shape_of_shape_op_output = shape.out_node().shape

        if np.array_equal(shape_of_shape_op_output, [4]):
            index = np.array([0, 2, 3, 1])  # NCHW -> NHWC
        elif np.array_equal(shape_of_shape_op_output, [5]):
            index = np.array([0, 2, 3, 4, 1])  # NCDHW -> NDHWC
        else:
            continue

        const = Const(graph, {'value': index}).create_node()
        axis_const = Const(graph, {'value': np.array(0, dtype=np.int64)}).create_node()
        gather = Gather(graph, {'name': shape.name + '/GatherNCHWtoNHWC'}).create_node()

        shape.out_port(0).get_connection().set_source(gather.out_port(0))
        shape.out_port(0).connect(gather.in_port(0))
        const.out_port(0).connect(gather.in_port(1))
        axis_const.out_port(0).connect(gather.in_port(2))

        constant_shape_paths.add(gather.id)

    # 2. Inserting Gather to NC* format
    #   - Search from Shape ops found in previous step for nodes without value that are n-th children of Shape op
    #       * MO can not propagate value, there is data path
    #   - Inserting Gather on ports which comes from operations in `constant_shape_paths` list
    constant_shape_ends = []

    for shape in shape_ops:
        constant_shape_ends.extend(self.search_of_constant_path_end(graph, node_name=shape.id,
                                                                    visited=constant_shape_paths))

    for end in constant_shape_ends:
        node = Node(graph, end)
        in_ports = [in_port for in_port in node.in_ports().values()
                    if in_port.get_source().node.id in constant_shape_paths]

        for in_port in in_ports:
            shape = in_port.data.get_shape()

            if np.array_equal(shape, [4]):
                index = np.array([0, 3, 1, 2])  # NHWC -> NCHW
            elif np.array_equal(shape, [5]):
                # NDHWC -> NCDHW: the inverse of [0, 2, 3, 4, 1] applied in step 1
                index = np.array([0, 4, 1, 2, 3])
            else:
                continue

            const = Const(graph, {'value': index}).create_node()
            axis_const = Const(graph, {'value': np.array(0, dtype=np.int64)}).create_node()
            gather = Gather(graph, {'name': node.name + '/GatherNHWCtoNCHW'}).create_node()

            in_port.get_source().connect(gather.in_port(0))
            in_port.get_connection().set_source(gather.out_port(0))
            const.out_port(0).connect(gather.in_port(1))
            axis_const.out_port(0).connect(gather.in_port(2))
def extract(node):
    """Register `node` as a Gather without adding any extra attributes."""
    Gather.update_node_stat(node, {})
    return __class__.enabled
def extract(cls, node):
    """Register `node` as a Gather with no additional attributes."""
    no_extra_attrs = {}
    Gather.update_node_stat(node, no_extra_attrs)
    return cls.enabled