def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    input = match['input']
    lstm = match['lstm']
    params = match['params'].value.copy()
    hidden_state = match['hidden_state']
    cell_state = match['cell_state']

    hidden_state_edge_attrs = deepcopy(graph.get_edge_data(hidden_state.id, lstm.id)[0])
    cell_state_edge_attrs = deepcopy(graph.get_edge_data(cell_state.id, lstm.id)[0])

    graph.remove_edge(match['params'].id, lstm.id)
    graph.remove_edge(match['hidden_state'].id, lstm.id)
    graph.remove_edge(match['cell_state'].id, lstm.id)

    self.repack_weights(graph, input, lstm, params)

    reshape = Reshape(graph, dict(dim=[lstm.in_node(0).shape[0], lstm.hidden_size]))

    if len(lstm.in_nodes()) > 2:
        hidden_state_edge_attrs['in'] = 3
        new_init_h = reshape.create_node_with_data([hidden_state], attrs=dict(name=lstm.name + '/HiddenStateResize'))
        graph.add_edge(new_init_h.id, lstm.id, **hidden_state_edge_attrs)

    if len(lstm.in_nodes()) > 3:
        cell_state_edge_attrs['in'] = 4
        new_init_c = reshape.create_node_with_data([cell_state], attrs=dict(name=lstm.name + '/CellStateResize'))
        graph.add_edge(new_init_c.id, lstm.id, **cell_state_edge_attrs)
def extract(cls, node):
    attrs = {
        'op': __class__.op,
        'dim': node.module.shape,
    }
    Reshape.update_node_stat(node, attrs)
    return cls.enabled
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """
    Finds the pattern: SoftmaxActivation -> DetectionOutput.
    DetectionOutput in IE expects a flattened input from SoftMax; that is why a Flatten
    layer has to be inserted between them.

    Parameters
    ----------
    graph : nx.MultiDiGraph
        Graph with loaded model.
    match : dict
        Patterns which were found in graph structure.
    """
    softmax_activation = match['softmax_activation']
    multi_box_detection = match['multi_box_detection']
    softmax_activation['axis'] = -1
    edge_data = graph.get_edge_data(softmax_activation.id, multi_box_detection.id)
    out_port = edge_data[0]['out']
    in_port = edge_data[0]['in']
    graph.remove_edge(softmax_activation.id, multi_box_detection.id)

    symbol_node = dict(
        op='Flatten',
        name=multi_box_detection.name + '/Reshape_',
        dim=[0, -1],
        axis=1,
        end_axis=-1
    )

    new_reshape_op = Reshape(graph, {'symbol_dict': symbol_node})
    new_reshape_node = new_reshape_op.create_node([softmax_activation])
    new_reshape_node['dim'] = [0, -1]
    create_edge(new_reshape_node, multi_box_detection, in_port=in_port, out_port=out_port)
def replace_pattern(graph: Graph, match: dict):
    select = match['op']
    if select.has_valid('format') and select['format'] == 'tf':
        condition = select.in_node(0)
        input_1 = select.in_node(1)
        input_2 = select.in_node(2)
        assert np.array_equal(input_1.shape, input_2.shape)

        if len(condition.shape) == 1 and len(input_1.shape) > 1:
            new_shape = np.array([0] + [1] * (len(input_1.shape) - 1), dtype=np.int64)
            reshape_shape_const = Const(graph, {'name': select.name + '/Reshape/Dim/',
                                                'value': new_shape}).create_node()
            unsqueeze_op = Reshape(graph, dict(name=select.name + '/Broadcast/')).create_node(inputs=[condition])
            reshape_shape_const.out_port(0).get_connection().set_destination(unsqueeze_op.in_port(1))

            select.in_port(0).disconnect()
            select.in_port(0).get_connection().set_source(unsqueeze_op.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

    bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # save the original shape
    original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)

    reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                  {'override_output_shape': True,
                                                   'name': node.soft_get('name') + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore original shape
    reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
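# A minimal numpy sketch (an illustration added here, not part of the transform above) of the
# reshape trick used by this pass: the Reshape dims [1, -1, 0, 0] fold batch and channel into
# one axis (0 means "copy the input dim"), so a per-channel MVN then normalizes every (n, c)
# slice over its spatial dims. For batch size 1 this reproduces batch-norm training statistics.
import numpy as np

x = np.random.randn(2, 3, 4, 5).astype(np.float32)
n, c, h, w = x.shape
fused = x.reshape(1, n * c, h, w)              # analogue of Reshape [1, -1, 0, 0]
mean = fused.mean(axis=(2, 3), keepdims=True)  # MVN statistics ...
var = fused.var(axis=(2, 3), keepdims=True)    # ... per fused channel
normalized = ((fused - mean) / np.sqrt(var + 1e-6)).reshape(x.shape)  # restore original shape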
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('axis'), 'The node "{}" does not have mandatory attribute "axis"'.format(node_name)

    flatten_node = FlattenONNX(graph, {'name': node_name + '/FlattenONNX_', 'axis': node.axis}).create_node()
    shape_node = Shape(graph, {'name': node_name + '/ShapeOf_'}).create_node()
    logsoftmax_node = LogSoftmax(graph, {'name': node_name + '/LogSoftmax_', 'axis': 1}).create_node()
    reshape_node = Reshape(graph, {}).create_node()

    rename_nodes([(node, node_name + '/delete'), (reshape_node, node_name)])

    shape_node.out_port(0).connect(reshape_node.in_port(1))
    logsoftmax_node.out_port(0).connect(reshape_node.in_port(0))
    flatten_node.out_port(0).connect(logsoftmax_node.in_port(0))

    source = node.in_port(0).get_source()
    flatten_node.in_port(0).connect(source)
    shape_node.in_port(0).connect(source)

    return [reshape_node.id]
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    # reshape tensor with batch indices to 2d
    unsqueeze_op = Unsqueeze(graph, {'unsqueeze_dims': np.array([1], dtype=np.int64)})
    unsqueeze_node = unsqueeze_op.create_node([node.in_node(2)])

    concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes'})
    concat_node = concat_op.create_node([unsqueeze_node, node.in_node(1)])

    # do not remove edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # input to the CropAndResize contains boxes coordinates in YXYX layout. But IE layer ROIPooling expects
    # coordinates in the XYXY layout, so convolution is added here to swap coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, 5)

    # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
    reshape_2d_op = Reshape(graph, dict(dim=np.array([-1, 5])))
    reshape_2d_node = reshape_2d_op.create_node([swapped_box_coordinates_node], dict(name='reshape_2d_'))
    create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def test_reshape_infer(self, input_value, input_shape, output_shape, ref_value, ref_shape):
    graph = build_graph(nodes_attributes,
                        [('input', 'data'),
                         ('data', 'reshape'),
                         ('output_shape', 'output_shape_data'),
                         ('output_shape_data', 'reshape'),
                         ('reshape', 'reshape_out')],
                        {'data': {'shape': input_shape, 'value': input_value},
                         'output_shape': {'value': output_shape, 'shape': output_shape.shape},
                         'output_shape_data': {'value': output_shape, 'shape': output_shape.shape},
                         })
    node = Node(graph, 'reshape')

    Reshape.infer(node)
    if ref_value is not None:
        self.assertTrue(strict_compare_tensors(node.out_port(0).data.get_value(), shape_array(ref_value)))
    self.assertTrue(strict_compare_tensors(node.out_port(0).data.get_shape(), shape_array(ref_shape)))
def add_reshape_for_new(graph: Graph, ss_node):
    log.info("StridedSlice op with new axis mask '{}' has been detected".format(ss_node.id))
    node = ss_node

    if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
        return

    shape_out = node.out_node().shape
    dim = shape_out.copy()
    ss_shape = []
    for i in range(0, len(node['new_axis_mask'])):
        if not node['new_axis_mask'][i]:
            ss_shape.append(shape_out[i])
        else:
            node['new_axis_mask'][i] = 0

    out_node = node.out_node(0)
    # insert data node for StridedSlice
    data_node = Op._create_data_node(graph, node.name + "/Reshape_new_data", {'shape': ss_shape})
    attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
    graph.remove_edge(node.id, out_node.id)
    graph.add_edge(node.id, data_node.id, **attrs)

    # insert Reshape
    reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
                                  dim=np.array(dim, dtype=np.int64)))
    reshape.create_node_with_data([data_node], reshape.attrs, data_nodes=[out_node])
def unsqueeze_num_directions(graph: Graph, match: dict):
    """ Assuming that the considered LSTM/GRU/RNN node should have num_directions in its output shape,
        insert Reshapes to restore it.
    """
    rnn_layer = match['rnn_layer']
    # num_directions is at the 1st position in the output shape, and at the 0th position in the
    # hidden and cell states; please refer to docs in this transform

    direction_dim = [1, 0, 0]  # index of dimension with direction index
    for i in rnn_layer.out_nodes():
        old_data_node = rnn_layer.out_node(i)
        old_shape = old_data_node.shape.copy()
        new_shape = np.delete(old_shape, direction_dim[i])

        data = Op._create_data_node(graph, name=rnn_layer.name + '/Out/{}/'.format(i), attrs={'shape': new_shape})
        graph.remove_edge(rnn_layer.id, old_data_node.id)
        graph.add_edge(rnn_layer.id, data.id, key=0, out=i)

        reshape = Reshape(graph, dict(dim=old_shape))
        reshape.create_node_with_data([data], dict(name=rnn_layer.name + '/SqueezeNumDirections/{}'.format(i)),
                                      data_nodes=[old_data_node])
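# Hedged numpy illustration (not part of the transform): the RNN layer internally emits a tensor
# whose num_directions dim was deleted, and the inserted Reshape restores the shape consumers
# still expect. The concrete shape below is an assumed example.
import numpy as np

old_shape = np.array([2, 1, 5, 16])              # e.g. Y in [batch, num_directions, seq_len, hidden_size]
direction_axis = 1
new_shape = np.delete(old_shape, direction_axis) # what the layer emits internally
y = np.zeros(new_shape)
y_restored = y.reshape(old_shape)                # what the inserted Reshape produces
assert y_restored.shape == tuple(old_shape)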
def find_and_replace_pattern(self, graph: Graph):
    for roll_node in graph.get_op_nodes(op='Roll'):
        # `continue` (not `return`) so the remaining Roll nodes are still processed
        if not roll_node.in_port(2).disconnected():
            continue
        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape to 1d tensor
        reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                         {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add zero const as axes input to roll
        const_zero = Const(graph, {'value': int64_array([0]), 'name': node_name + '/axes'}).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # reshape to original shape; the ShapeOf must read the tensor *before* the flattening
        # Reshape, so it is attached to reshape_to_1d's input rather than roll_node's
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        reshape_to_1d.in_port(0).get_connection().add_destination(shape_of.in_port(0))
        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        rename_nodes([(roll_node, node_name + '/roll'), (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
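# For reference, a small numpy check (added illustration): the flatten -> roll along axis 0 ->
# restore-shape decomposition built above matches numpy's own semantics for roll without axes.
import numpy as np

x = np.arange(12).reshape(3, 4)
shift = 5
decomposed = np.roll(x.reshape(-1), shift, axis=0).reshape(x.shape)
assert np.array_equal(decomposed, np.roll(x, shift))  # np.roll with axis=None flattens too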
def extract(cls, node):
    dim = onnx_attr(node, 'shape', 'ints', None)
    if dim is not None:
        dim = np.array(dim, dtype=np.int64)
        Reshape.update_node_stat(node, {'dim': dim})
    else:
        Reshape.update_node_stat(node)
    return cls.enabled
def add_output_reshape(graph: Graph, match: dict):
    """
    Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions],
    we need to add a reshape from the common format [batch_size, num_directions, seq_len, hidden_size]
    to the MXNet format.
    """
    lstm = match['rnn_layer']
    input = match['input']
    if not lstm.has_num_directions:
        return

    old_data_node = lstm.out_node(0)
    num_directions = 2 if lstm.direction in ['bidirectional'] else 1
    mxnet_shape = lstm.out_node(0).shape.copy()

    if lstm.batch_dim == 0:
        mo_shape = np.array([input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim], lstm.hidden_size],
                            dtype=np.int64)
    else:
        mo_shape = np.array([input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim], lstm.hidden_size],
                            dtype=np.int64)

    if lstm.has_num_directions:
        mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

    lstm_name = lstm.soft_get('name', lstm.id)

    new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
    graph.remove_edge(lstm.id, old_data_node.id)
    graph.add_edge(lstm.id, new_data.id, key=0, out=0)

    # Add Transpose
    permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
                                  'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
    permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'}
                             ).create_node_with_data([new_data, permute_order])

    # Add Reshape
    reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
    reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
                                     'value': mxnet_shape}).create_node_with_data()
    reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])
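# A small numpy sketch of the layout conversion performed above (added illustration, assuming
# batch_dim == 0 and the example sizes below): common format [B, D, T, H] is transposed with
# order (0, 2, 1, 3) and then reshaped to the MXNet format [B, T, D * H].
import numpy as np

b, d, t, h = 2, 2, 5, 16
y_common = np.random.randn(b, d, t, h)                         # [batch, num_dirs, seq, hidden]
y_mxnet = y_common.transpose(0, 2, 1, 3).reshape(b, t, d * h)  # [batch, seq, hidden * num_dirs]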
def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
    mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
    unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([-1])},
                                                 {'name': mx_fft_name + '/Unsqueeze'})
    rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

    mx_fft_connection = mx_fft.in_port(0).get_connection()
    mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
    mx_fft_connection.get_source().connect(rank_node.in_port(0))

    add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                           {'name': mx_fft_name + '/Add'}, rank_node)
    broadcast_node1 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Pad_broadcast'})
    add_node.out_port(0).connect(broadcast_node1.in_port(1))

    scatter_node = create_op_with_const_inputs(graph, ScatterUpdate,
                                               {2: int64_array(1), 3: int64_array(0)},
                                               {'name': mx_fft_name + '/ScatterUpdate'})
    broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
    rank_node.out_port(0).connect(scatter_node.in_port(1))

    pad_node = Pad(graph, {'name': mx_fft_name + '/Pad', 'mode': 'constant'}
                   ).create_node([unsqueeze_node, broadcast_node1, scatter_node])

    dft_node = create_op_with_const_inputs(graph, DFT, {1: int64_array([-1])},
                                           {'name': mx_fft_name + '/DFT', 'in_ports_count': 2},
                                           pad_node)

    sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                           {'name': mx_fft_name + '/Sub'})
    rank_node.out_port(0).connect(sub_node.in_port(0))
    broadcast_node2 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Reshape_broadcast'})
    sub_node.out_port(0).connect(broadcast_node2.in_port(1))
    concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                              {'name': mx_fft_name + '/New_shape',
                                               'in_ports_count': 2,
                                               'axis': 0},
                                              broadcast_node2)

    reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

    mx_fft.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (reshape_node, mx_fft_name)])
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
    data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
    for node in data_nodes:
        # Get all requested shapes for the current node.
        # This mapping will contain pairs like {shape: [list of consumer nodes]}
        mapping = {}
        for consumer in node.out_nodes():
            edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
            if 'new_shape' in edge_attrs:
                if np.array_equal(edge_attrs['new_shape'], node.shape):
                    continue
                new_shape = tuple([x for x in edge_attrs['new_shape']])
                if new_shape not in mapping:
                    mapping.update({new_shape: [consumer]})
                else:
                    mapping[new_shape].append(consumer)

        if node.has_valid('value'):
            # Check whether the requested shapes are the same; if they differ, duplicate the value
            for shape_key in mapping.keys():
                shape = list(shape_key)

                new_value = np.reshape(node.value, shape)
                node_copy = Op.create_input_data_node(graph, node.id + '/copy', value=np.array(new_value))
                for consumer in mapping[shape_key]:
                    edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
                    del edge_attrs['new_shape']

                    # Remove edge from the previous data node and connect the new data node with its consumer
                    graph.remove_edge(node.id, consumer.id)
                    graph.add_edge(node_copy.id, consumer.id, **edge_attrs)
        else:
            # Insert a Reshape layer between the data node and the consumer
            for shape_key in mapping.keys():
                shape = list(shape_key)
                reshape = Reshape(graph, attrs={'dim': shape, 'name': 'EltwiseReshapeNormalization'})
                reshape_data = reshape.create_node_with_data(inputs=[node])

                # Iterate over consumers and reconnect them to the Reshape layer output
                for consumer in mapping[shape_key]:
                    edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
                    del edge_attrs['new_shape']

                    # Reconnect edge from the original data node to the Reshape output data node
                    graph.remove_edge(node.id, consumer.id)
                    graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)
def onnx_reshape_ext(node):
    '''
    Extract the ONNX Reshape op for different opset versions.
    Supports both the latest Reshape, which takes two inputs, and Reshape-1,
    which has one input and encodes the shape in an attribute.
    '''
    dim = onnx_attr(node, 'shape', 'ints', None)
    if dim is not None:
        dim = np.array(dim, dtype=np.int64)
        Reshape.update_node_stat(node, {'dim': dim})
    else:
        Reshape.update_node_stat(node)
    return node.graph.node[node.id]
def replace_pattern(self, graph: Graph, match: dict):
    if match['axis'].value is None or match['input'].shape is None:
        return
    dims = len(match['input'].shape)
    ones = np.ones(dims, dtype=np.int64)
    axis = np.array(match['axis'].value)
    axis = axis if axis.ndim != 0 else np.array([axis], dtype=np.int64)

    mean = graph.node[match['mean'].node]
    mean['stride'] = np.array(ones)
    # TODO: need to check axis with real layout
    spatial_dims = np.array(axis)
    mean['spatial_dims'] = spatial_dims
    mean['pad'] = np.zeros((dims, 2), np.int64)
    mean['pad_spatial_shape'] = np.array(mean['pad'][spatial_dims])
    window = np.array(ones)
    window[spatial_dims] = match['input'].shape[spatial_dims]
    mean['window'] = window
    mean['TF_op'] = mean['op']
    mean['op'] = 'AvgPool'
    mean['pool_method'] = 'avg'
    mean['rounding_type'] = 'ceil'
    mean['exclude_pad'] = 'true'
    mean['kernel_spatial'] = window[spatial_dims]
    graph.remove_edge(match['axis'].node, match['mean'].node)
    mean['permute_attrs'] = PermuteAttrs().update_attrs(attrs=[('pad', 'input:0'),
                                                               ('stride', 'input:0'),
                                                               ('window', 'input:0'),
                                                               ('spatial_dims', 'input:0')])

    if match['mean'].keep_dims == False:
        output = match['mean'].out_node()
        pool_node = match['mean']

        # Keep dims for AvgPool
        shape = np.array(output.shape)
        for idx in spatial_dims:
            shape = np.insert(shape, idx, 1)

        graph.remove_edge(pool_node.id, output.id)
        # Create new data for pool with all dims
        pool_data = Op.create_data_node(graph, pool_node, {'shape': np.array(shape)})

        # Create and connect reshape node
        reshape_op = Reshape(graph, {'dim': np.array(output.shape)})
        reshape_node = reshape_op.create_node([pool_data],
                                              dict(name='Reshape_',
                                                   permute_attrs=PermuteAttrs().update_attrs(
                                                       attrs=[('dim', 'output:0')])))
        graph.create_edge(reshape_node, output)
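# Numpy analogue (an added sketch, with assumed example shapes) of the keep_dims repair above:
# AvgPool keeps the reduced spatial dims as 1s, and the trailing Reshape removes them so the
# output shape matches the original Mean(keep_dims=False) node.
import numpy as np

mean_out_shape = np.array([2, 7])   # Mean(keep_dims=False) over axes (1, 2) of a [2, 4, 5, 7] input
spatial_dims = np.array([1, 2])
pool_shape = mean_out_shape.copy()
for idx in spatial_dims:
    pool_shape = np.insert(pool_shape, idx, 1)  # -> [2, 1, 1, 7], what AvgPool emits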
def extract(cls, node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    dim = attrs.tuple("shape", int, None)
    reverse = attrs.bool("reverse", False)
    update_attrs = {
        'dim': int64_array(dim),
        'reverse': reverse,
    }
    for d in dim:
        if d in [-2, -3, -4] or reverse:
            MXReshape.update_node_stat(node, update_attrs)
            return cls.enabled

    # update the attributes of the node
    Reshape.update_node_stat(node, update_attrs)
    return cls.enabled
def broadcast_with_reshape(port):
    # note: graph, node and input_port come from the enclosing scope
    input_shape = input_port.data.get_shape()
    reshape_dims = np.zeros(len(input_shape), dtype=np.int64)
    # ones before the data dims, the data dims themselves starting at node.axis, then ones after
    for i in range(0, node.axis):
        reshape_dims[i] = 1
    data_shape = port.data.get_shape()
    for i in range(node.axis, node.axis + len(data_shape)):
        reshape_dims[i] = data_shape[i - node.axis]
    for i in range(node.axis + len(data_shape), len(input_shape)):
        reshape_dims[i] = 1

    reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=reshape_dims)).create_node()
    port.get_connection().set_destination(reshape.in_port(0))
    reshape.out_port(0).connect(port)
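# Hedged numpy illustration (added, with assumed shapes): for node.axis == 1 and a 4-D input,
# a per-channel tensor of shape (C,) is reshaped to (1, C, 1, 1) so that a plain elementwise
# operation broadcasts it across the batch and spatial dims.
import numpy as np

x = np.ones((2, 3, 4, 5))               # the node's full-rank input
scale = np.array([1.0, 2.0, 3.0])       # the second input carried on `port`
reshape_dims = (1, 3, 1, 1)             # ones before and after the data dims
out = x * scale.reshape(reshape_dims)   # broadcasts over batch and spatial dims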
def add_convolution_to_swap_xy_coordinates(graph: nx.MultiDiGraph, input_node: Node, coordinates_size: int):
    """
    The function adds a convolution node after the node 'input_node' to swap the xy coordinates of the boxes
    produced by 'input_node'. It is expected that the box coordinates are located in the fastest changing
    dimension of the 'input_node' output, i.e. the tensor could be reshaped to [num_boxes, 4] or [num_boxes, 5].
    If the size is 5, then the 0-th element of each of the num_boxes blocks is kept as is, element 1 is swapped
    with element 2, and element 3 is swapped with element 4. This is the case when the box coordinates are
    produced by the layer "Proposal". The exact number of elements in each block is equal to the
    'coordinates_size' parameter.
    :param graph: graph to operate on.
    :param input_node: node producing boxes coordinates.
    :param coordinates_size: integer value equal to 4 or 5.
    :return: convolution node that swaps coordinates.
    """
    # only tensors with 4 or 5 numbers describing a box are supported
    assert coordinates_size in [4, 5]
    input_reshape_4d_op = Reshape(input_node.graph, dict(dim=np.array([-1, 1, 1, coordinates_size])))
    input_reshape_4d_node = input_reshape_4d_op.create_node([input_node], dict(name=input_node.name + '/reshape_4d'))
    update_attrs(input_reshape_4d_node, 'shape_attrs', 'dim')

    if coordinates_size == 5:
        # the zero-indexed element is not a box coordinate ("batch id" in case of Proposal)
        conv_filter_data = np.array([[[[1, 0, 0, 0, 0],
                                       [0, 0, 1, 0, 0],
                                       [0, 1, 0, 0, 0],
                                       [0, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 0]]]], dtype=np.float32)
    else:
        conv_filter_data = np.array([[[[0, 1, 0, 0],
                                       [1, 0, 0, 0],
                                       [0, 0, 0, 1],
                                       [0, 0, 1, 0]]]], dtype=np.float32)

    conv_filter_const_op = Const(graph, dict(value=conv_filter_data))
    conv_filter_const_node = conv_filter_const_op.create_node([], dict(name=input_node.name + '/weights'))

    conv_op = Convolution(graph, {
        'bias_addable': True,
        'channel_dims': np.array([3]),
        'batch_dims': np.array([0]),
        'input_feature_channel': 2,
        'output_feature_channel': 3,
        'group': 1,
        'layout': 'NHWC',
    })
    return conv_op.create_node([input_reshape_4d_node, conv_filter_const_node],
                               dict(name=input_node.name + "/conv"))
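# Sanity-check sketch (added, numpy): up to the exact filter layout the Convolution uses, the
# 5x5 filter above applied over [-1, 1, 1, 5] data amounts to the matrix multiply below, which
# keeps element 0 (the batch id) and turns each (y, x) pair into (x, y) order.
import numpy as np

w = np.array([[1, 0, 0, 0, 0],
              [0, 0, 1, 0, 0],
              [0, 1, 0, 0, 0],
              [0, 0, 0, 0, 1],
              [0, 0, 0, 1, 0]], dtype=np.float32)
box = np.array([7, 0.1, 0.2, 0.3, 0.4], dtype=np.float32)  # [batch_id, y1, x1, y2, x2]
assert np.allclose(box @ w, [7, 0.2, 0.1, 0.4, 0.3])       # [batch_id, x1, y1, x2, y2]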
def generate_sub_graph(self, graph: Graph, match: SubgraphMatch):
    # IE DetectionOutput layer consumes flattened confidences and locations tensors.
    # That is why we add reshapes before them.
    locs_node = match.single_input_node(0)
    conf_node = match.single_input_node(1)
    prior_boxes_node = match.single_input_node(2)

    locs_out_nodes = locs_node[0].out_nodes()
    assert len(locs_out_nodes) == 1
    locs_out_node = locs_out_nodes[list(locs_out_nodes.keys())[0]]
    assert locs_out_node.op == "Result", locs_out_node.op
    graph.remove_node(locs_out_node.id)

    conf_out_nodes = conf_node[0].out_nodes()
    assert len(conf_out_nodes) == 1
    conf_out_node = conf_out_nodes[list(conf_out_nodes.keys())[0]]
    assert conf_out_node.op == "Result", conf_out_node.op
    graph.remove_node(conf_out_node.id)

    # reshape operation to flatten locations tensor
    const = Const(graph, {'value': int64_array([0, -1])}).create_node()
    reshape_loc_node = Reshape(graph, {}).create_node([locs_node, const],
                                                      dict(name='DetectionOutput_Reshape_loc_'))

    # reshape operation to flatten confidence tensor
    reshape_conf_node = Reshape(graph, {}).create_node([conf_node, const],
                                                       dict(name='DetectionOutput_Reshape_conf_'))

    # remove the Result node after the priors node
    assert prior_boxes_node[0].out_node().op == "Result"
    graph.remove_node(prior_boxes_node[0].out_node().id)

    # reshape operation for prior boxes tensor
    const = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    reshape_priors_node = Reshape(graph, {}).create_node([prior_boxes_node, const],
                                                         dict(name='DetectionOutput_Reshape_priors_'))

    # create DetectionOutput node with three inputs: locations, confidences and prior boxes
    detection_output_op = DetectionOutput(graph, match.custom_replacement_desc.custom_attributes)
    detection_output_node = detection_output_op.create_node(
        [reshape_loc_node, reshape_conf_node, reshape_priors_node],
        dict(name=detection_output_op.attrs['type'] + '_'))
    PermuteAttrs.set_permutation(reshape_priors_node, detection_output_node, None)

    # create Result node to mark DetectionOutput as a graph output operation
    output_op = Result(graph)
    output_op.create_node([detection_output_node], dict(name='sink_'))
    return {}
def replace_pattern(graph: Graph, match: dict):
    flatten = match['reshape']
    output_shape = np.copy(flatten.out_port(0).data.get_shape())
    output_shape[0] = 0

    reshape = Reshape(graph, dict(name=flatten.id)).create_node()
    dim = Const(graph, dict(name=flatten.id + '/DimData', value=output_shape)).create_node()

    flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
    dim.out_port(0).connect(reshape.in_port(1))
    flatten.out_port(0).get_connection().set_source(reshape.out_port(0))

    reshape['force_precision_in_ports'] = {1: 'int64'}
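# Note on the dims used above (an added sketch): IE Reshape treats 0 as "copy this dimension
# from the input", so setting output_shape[0] = 0 keeps the batch dim while the remaining dims
# flatten. A numpy analogue:
import numpy as np

x = np.zeros((2, 3, 4, 5))
flattened = x.reshape(x.shape[0], -1)   # analogue of Reshape with dim [0, -1]
assert flattened.shape == (2, 60)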
def extract(node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    dim = attrs.tuple("shape", int, None)
    update_attrs = {
        'dim': np.array(dim),
    }
    for d in dim:
        if d in [-2, -3, -4]:
            log.error('The attribute "shape" of the operation "{}" contains value "{}" which is not supported.'
                      .format(node.soft_get('name'), d))
            return False

    # update the attributes of the node
    Reshape.update_node_stat(node, update_attrs)
    return __class__.enabled
def decompose_shuffle_channel(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    # Reshape [input_batch, group, input_channels/group, -1]
    batch = node_to_get_batch_value(shape)
    group = Const(graph, dict(name=name + '/Rows', value=int64_array([node.group]))).create_node()
    const = Const(graph, dict(name=name + '/Const', value=int64_array([-1]))).create_node()

    input_channels = node_to_get_features_dimension_value(shape)
    output_channels = create_op_node_with_second_input(
        graph, Div, np.int64(node.group), {'name': name + '/Cols'}, input_node=input_channels)
    i_output_channels = Cast(graph, {'name': output_channels.name + '/Convert',
                                     'dst_type': np.int64}).create_node()
    output_channels.out_port(0).connect(i_output_channels.in_port(0))

    reshape_split_dim = new_shape_node_from_shape_nodes([batch, group, i_output_channels, const])
    reshape_split_node = Reshape(graph, dict(name=name + '/Reshape_split_')).create_node()
    reshape_split_dim.out_port(0).connect(reshape_split_node.in_port(1))

    # Transpose(0, 2, 1, 3)
    transpose_node = create_op_node_with_second_input(
        graph, Transpose, int64_array([0, 2, 1, 3]), {'name': name + '/Transpose_'},
        input_node=reshape_split_node)

    # Reshape back to the input shape
    reshape_concat = Reshape(graph, dict(name=name)).create_node()
    rename_node(reshape_concat, name)

    shape.out_port(0).connect(reshape_concat.in_port(1))
    transpose_node.out_port(0).connect(reshape_concat.in_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(reshape_split_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_concat.out_port(0))
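# Reference numpy sketch (added) of the channel shuffle the decomposition above builds in the
# graph: reshape to [N, group, C // group, -1], swap the two channel axes, and reshape back.
import numpy as np

def shuffle_channels(x, group):
    n, c, h, w = x.shape
    return (x.reshape(n, group, c // group, -1)
             .transpose(0, 2, 1, 3)
             .reshape(n, c, h, w))

shuffled = shuffle_channels(np.arange(2 * 6 * 4 * 4).reshape(2, 6, 4, 4), group=3)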
def squeeze_initial_states(graph: Graph, match: dict):
    """
    Squeeze input initial states of recurrent node to 2-D shape.
    """
    hidden_init_port = 5
    cell_init_port = 6

    rnn_layer = match['rnn_layer']
    # Add input ports to rnn_layer
    rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
    reshape = Reshape(graph, {})

    assert hidden_init_port in rnn_layer.in_nodes()
    init_h = rnn_layer.in_node(hidden_init_port)
    edge_attrs = deepcopy(graph.get_edge_data(init_h.id, rnn_layer.id)[0])
    edge_attrs['in'] = hidden_init_port
    graph.remove_edge(init_h.id, rnn_layer.id)

    new_dim = int64_array([rnn_layer.in_node(0).shape[rnn_layer.batch_dim], rnn_layer.hidden_size])
    reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/HiddenStateResizeDim',
                                     'value': new_dim}).create_node_with_data()
    new_init_h = reshape.create_node_with_data([init_h, reshape_dim_data],
                                               dict(name=rnn_layer.name + '/HiddenStateResize'))
    graph.add_edge(new_init_h.id, rnn_layer.id, **edge_attrs)

    if rnn_layer.op == 'LSTM':
        assert cell_init_port in rnn_layer.in_nodes()

        init_c = rnn_layer.in_node(cell_init_port)
        edge_attrs = deepcopy(graph.get_edge_data(init_c.id, rnn_layer.id)[0])
        edge_attrs['in'] = cell_init_port
        graph.remove_edge(init_c.id, rnn_layer.id)
        reshape_dim_data = Const(graph, {'name': rnn_layer.name + '/CellStateResizeDim',
                                         'value': new_dim}).create_node_with_data()
        new_init_c = reshape.create_node_with_data([init_c, reshape_dim_data],
                                                   dict(name=rnn_layer.name + '/CellStateResize'))
        graph.add_edge(new_init_c.id, rnn_layer.id, **edge_attrs)
def replace_pattern(self, graph: Graph, match: dict):
    conv = match['conv']

    assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
    out_data = conv.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
    out_shape = out_data.shape

    if out_shape.size != 3:
        return

    assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have more than 1 input data node".format(
        conv.id)
    inp_data = conv.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)

    inp_shape = inp_data.shape
    new_inp_shape = np.insert(inp_shape, 2, 1)

    # setting to None to be overwritten by the infer function
    conv.kernel_spatial_idx = None
    conv.spatial_dims = None

    # inserting a fake H dimension
    conv.dilation = np.insert(conv.dilation, 2, 1)
    conv.kernel_spatial = np.append([1], conv.kernel_spatial)
    conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
    conv.stride = np.insert(conv.stride, 2, 1)

    weights_node = conv.in_node(1)
    weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
    weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

    reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    conv.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
    # use reshape_back.id here to avoid a name collision with the first Dim constant
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
    conv.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    conv.infer(conv)
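# Hedged numpy sketch (added, with assumed example shapes) of the shape bookkeeping above: the
# 3-D convolution input [N, C, W] gains a fake H dimension at index 2, the weights gain the
# matching singleton, the 2-D convolution runs on [N, C, 1, W], and the trailing Reshape
# restores a 3-D output.
import numpy as np

inp_shape = np.array([1, 16, 128], dtype=np.int64)            # [N, C, W] input of the 1D conv
new_inp_shape = np.insert(inp_shape, 2, 1)                    # [N, C, 1, W] fed to the 2D conv
weights = np.zeros((32, 16, 3))                               # [O, I, K] 1D kernel
weights_4d = weights.reshape(np.insert(weights.shape, 2, 1))  # [O, I, 1, K]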
def add_output_reshape(graph: Graph, match: dict):
    """
    Since the MXNet Y output shape is [batch_size, seq_len, hidden_size * num_directions],
    we need to add a reshape from the common format [batch_size, num_directions, seq_len, hidden_size]
    to the MXNet format.
    """
    lstm = match['rnn_layer']
    input = match['input']
    if not lstm.has_num_directions:
        return
    old_data_node = lstm.out_node(0)
    num_directions = 2 if lstm.direction in ['bidirectional'] else 1
    mxnet_shape = lstm.out_node(0).shape.copy()

    if lstm.batch_dim == 0:
        mo_shape = np.array([input.shape[lstm.batch_dim], input.shape[lstm.sequence_dim], lstm.hidden_size],
                            dtype=np.int64)
    else:
        mo_shape = np.array([input.shape[lstm.sequence_dim], input.shape[lstm.batch_dim], lstm.hidden_size],
                            dtype=np.int64)

    if lstm.has_num_directions:
        mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))

    new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
    graph.remove_edge(lstm.id, old_data_node.id)
    graph.add_edge(lstm.id, new_data.id, key=0, out=0)

    # Add Permute
    permute_order = np.array([0, 2, 1, 3], dtype=np.int64)
    permute = Permute(graph, dict(order=permute_order))
    permute_data = permute.create_node_with_data([new_data], dict(name=lstm.name + '/Permute_mxnet/'))

    # Add Reshape
    reshape = Reshape(graph, dict(dim=mxnet_shape))
    reshape.create_node_with_data([permute_data], dict(name=lstm.name + '/Reshape_mxnet/'),
                                  data_nodes=[old_data_node])
def extract(cls, node):
    param = node.pb.reshape_param

    # param is a protobuf message, so its fields are read as attributes, not subscripts
    if param.axis != 0:
        log.error('The operation "Reshape" has attribute "axis" with unsupported value "{}"'.format(param.axis))
        return False

    if param.num_axes != -1:
        log.error('The operation "Reshape" has attribute "num_axes" with unsupported value "{}"'.format(
            param.num_axes))
        return False

    Reshape.update_node_stat(node, {
        'dim': list(param.shape.dim),
    })
    return cls.enabled
def replace_sub_graph(self, graph: Graph, match: dict):
    if not check_applicability(match):
        return

    reshape = match['reshape']
    div_name = match['division'].name

    input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()
    shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
    c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
    c = c1 * c2

    new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                   dict(name=div_name + '/first_reshape/MVN_T_'))
    permute_order = int64_array([0, 1, 2, 4, 3])
    first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                     dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

    add = match['add']
    variance = match['variance']
    eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
    eps = add.in_port(eps_port_num).get_connection().get_source().node
    mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                           dict(name=div_name + '/MVN/MVN_T_',
                                                eps=eps.value, normalize_variance=1,
                                                eps_mode='inside_sqrt'))
    first_permute.out_port(0).connect(mvn_node.in_port(0))

    second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                      dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
    new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
    second_permute.out_port(0).connect(new_reshape2.in_port(0))
    gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                           int64_array([1, 1, 1, c]))
    new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                               dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
    beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                          int64_array([1, 1, 1, c]))
    new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                dict(name=match['add2'].name + '/MVN_T_'), new_mul)

    transpose_connection = match['transpose'].in_port(0).get_connection()
    before_transpose = transpose_connection.get_source().node
    transpose_connection.set_destination(new_reshape.in_port(0))
    input_shape.out_port(0).connect(new_reshape2.in_port(1))
    before_transpose.out_port(0).connect(input_shape.in_port(0))
    match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
def replace_op(self, graph: Graph, node: Node) -> list:
    input_node = node.in_node(0)

    input_reshape_node = Reshape(graph,
                                 {'name': 'Reshape/' + node.name,
                                  'infer': Reshape.kaldi_infer}).create_node([input_node])

    pooling_node = Pooling(graph, graph.node[node.id]).create_node([input_reshape_node])

    output_reshape_node = Reshape(graph,
                                  {'name': node.name + '/Reshape/',
                                   'infer': Reshape.kaldi_infer}).create_node([pooling_node])

    return [output_reshape_node.id]