def find_and_replace_pattern(self, graph: Graph):
    """
    Insert a Reshape in front of 1-D inputs of elementwise operations with a 4-D or 5-D output.

    A 1-D input is reshaped only when its length equals the output's feature dimension (located via
    ``get_features_dim`` for ``graph.graph['layout']``). The target shape comes from
    ``shape_for_layout`` with batch=1, features=<feature count>, height=width=1, and depth=1 for a
    5-D output.

    :param graph: graph to transform; ``graph.graph['layout']`` selects the feature axis
    """
    layout = graph.graph['layout']
    for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
        out_shape = eltwise_op_node.out_port().data.get_shape()
        if not 4 <= len(out_shape) <= 5:
            continue
        out_features = out_shape[get_features_dim(layout, len(out_shape))]

        for port, node in eltwise_op_node.in_nodes().items():
            # the output rank is 4 or 5 here, so a 1-D input never has the output's rank
            if len(node.shape) != 1 or out_features != node.shape[0]:
                continue

            new_shape = shape_for_layout(layout, batch=1, features=out_features, height=1, width=1,
                                         depth=1 if len(out_shape) == 5 else None)
            dim_const = Const(graph, {'value': new_shape, 'name': node.id + '/Dim'}).create_node()
            reshape_op = Reshape(graph, attrs={'dim': new_shape, 'name': node.id + '/Broadcast'}).create_node()

            # Re-route only the edge that feeds this eltwise port. Taking the producer's
            # out_port(0) connection instead would move every consumer of that producer onto the
            # Reshape (and would pick the wrong output if the producer's port is not 0).
            eltwise_op_node.in_port(port).get_connection().set_destination(reshape_op.in_port(0))
            reshape_op.in_port(1).connect(dim_const.out_port(0))
            reshape_op.out_port(0).connect(eltwise_op_node.in_port(port))
def squeeze_initial_states(graph: Graph, match: dict):
    """
    Squeeze input initial states of recurrent node to 2-D shape.

    The initial hidden state (input port 5) and, for LSTM, the initial cell state (input port 6)
    are each passed through a Reshape whose target shape is
    concat(input_shape[rnn_layer.batch_dim], [rnn_layer.hidden_size]). The input shape is read at
    run time from input port 0 of the recurrent layer.
    """
    hidden_init_port, cell_init_port = 5, 6
    rnn_layer = match['rnn_layer']

    # make sure ports 0..6 exist on the recurrent layer
    rnn_layer.add_sequence_of_ports(type='in', rng=range(7))
    layer_name = rnn_layer.soft_get('name', rnn_layer.id)

    assert hidden_init_port in rnn_layer.in_nodes()
    hidden_size = rnn_layer.hidden_size

    # target shape = [batch taken from input 0, hidden_size]
    data_shape = Shape(graph, dict(name=layer_name + '/ShapeOf')).create_node()
    rnn_layer.in_port(0).get_source().connect(data_shape.in_port(0))
    batch_value = node_to_get_shape_value_of_indices(data_shape, int64_array([rnn_layer.batch_dim]))
    target_shape = create_op_node_with_second_input(
        graph, Concat,
        second_input_value=int64_array([hidden_size]),
        op_attrs=dict(name=layer_name + '/HiddenStateResizeDim', in_ports_count=2, axis=0),
        input_node=batch_value)

    hidden_reshape = Reshape(graph, dict(name=layer_name + '/HiddenStateResize',
                                         override_output_shape=True)).create_node()
    target_shape.out_port(0).connect(hidden_reshape.in_port(1))
    rnn_layer.in_port(hidden_init_port).get_connection().insert_node(hidden_reshape)

    if rnn_layer.op != 'LSTM':
        return

    # LSTM additionally carries a cell state that is squeezed to the same 2-D shape
    assert cell_init_port in rnn_layer.in_nodes()
    cell_reshape = Reshape(graph, dict(name=layer_name + '/CellStateResize',
                                       override_output_shape=True)).create_node()
    target_shape.out_port(0).connect(cell_reshape.in_port(1))
    rnn_layer.in_port(cell_init_port).get_connection().insert_node(cell_reshape)
def replace_op(self, graph: Graph, node: Node):
    """
    Decompose a LogSoftmax node with an arbitrary ``axis``.

    The replacement sub-graph is FlattenONNX(axis) -> LogSoftmax(axis=1) -> Reshape, where the
    Reshape restores the shape of the original input (taken by a ShapeOf on that input).

    :param graph: graph being transformed
    :param node: LogSoftmax node to replace; must have a valid ``axis`` attribute
    :return: single-element list with the id of the Reshape that replaces ``node``
    """
    node_name = node.soft_get('name', node.id)
    assert node.has_valid('axis'), 'The node "{}" does not have mandatory attribute "axis"'.format(node_name)

    flatten = FlattenONNX(graph, {'name': node_name + '/FlattenONNX_', 'axis': node.axis}).create_node()
    input_shape = Shape(graph, {'name': node_name + '/ShapeOf_'}).create_node()
    log_softmax = LogSoftmax(graph, {'name': node_name + '/LogSoftmax_', 'axis': 1}).create_node()
    restore_shape = Reshape(graph, {}).create_node()

    # the final Reshape takes over the original node name
    rename_nodes([(node, node_name + '/delete'), (restore_shape, node_name)])

    input_shape.out_port(0).connect(restore_shape.in_port(1))
    log_softmax.out_port(0).connect(restore_shape.in_port(0))
    flatten.out_port(0).connect(log_softmax.in_port(0))

    producer = node.in_port(0).get_source()
    flatten.in_port(0).connect(producer)
    input_shape.in_port(0).connect(producer)

    return [restore_shape.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace a Flatten node with a Reshape.

    The Reshape pattern is:
      * axis == 0 -> [1, -1]
      * axis == 1 -> [0, -1]
      * otherwise -> a 2-D pattern built at run time from the product (ReduceProd, keep_dims) of
        the input dims before ``axis`` (positive axis) or from ``axis`` to the end (negative
        axis), combined with -1 on the other side.
    """
    flatten = match['flatten']
    name = flatten.soft_get('name', flatten.id)
    assert flatten.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
    axis = flatten.axis

    if axis == 0:
        constant_pattern = int64_array([1, -1])
    elif axis == 1:
        constant_pattern = int64_array([0, -1])
    else:
        constant_pattern = None

    if constant_pattern is not None:
        dim = Const(graph, {'value': constant_pattern}).create_node()
    else:
        input_shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        dim_indices = list(range(axis)) if axis > 0 else list(range(axis, 0))
        gathered_dims = node_to_get_shape_value_of_indices(input_shape, dim_indices)
        collapsed = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        remainder = Const(graph, {'value': int64_array([-1])}).create_node()

        flatten.in_port(0).get_source().connect(input_shape.in_port(0))
        gathered_dims.out_port(0).connect(collapsed.in_port(0))

        # the collapsed product sits on the side of `axis` it was gathered from
        parts = [collapsed, remainder] if axis > 0 else [remainder, collapsed]
        dim = new_shape_node_from_shape_nodes(parts)

    reshape = Reshape(graph, {'name': flatten.id + '/Reshape'}).create_node()
    reshape.in_port(1).connect(dim.out_port(0))
    flatten.out_port(0).get_connection().set_source(reshape.out_port(0))
    flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """
    Convert a training-mode BatchNorm node into its inference form.

    * ``is_training`` is cleared.
    * Inputs 3 and 4 are re-sourced from constant zeros (``/mean``) and ones (``/std``) with the
      shape of input 1 (the scale input).
    * The data input (port 0) is routed through Reshape([1, -1, 0, 0]) -> MVN -> Reshape back to
      the original input shape, which is captured by a ShapeOf on the original data producer.
    """
    node = match['op']
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

    bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # Save the original shape. The ShapeOf node gets its own name; it previously reused the
    # producer node's name verbatim, leaving two different nodes with the same name.
    original_shape = Shape(graph, {'name': node.name + '/original_shape'}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)

    # name suggests batch and channels are fused into one axis before MVN — pattern [1, -1, 0, 0]
    reshape_4d = create_op_node_with_second_input(
        graph, Reshape, int64_array([1, -1, 0, 0]),
        {'override_output_shape': True, 'name': node.soft_get('name') + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore original shape
    reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace an MXReshape node with a regular Reshape whose target shape is assembled at run time.

    ``self.resolve`` is called repeatedly (at most ``len(dim)`` times) to turn entries of
    ``dim`` into shape-producing nodes; their outputs are concatenated along axis 0 and fed to
    the Reshape's shape input.
    """
    mx_reshape = match['mxreshape']
    in_idx = 0
    dim_idx = 0

    input_shape = Shape(graph, dict(name=mx_reshape.id + '/ShapeMXReshape')).create_node()
    input_shape.in_port(0).connect(mx_reshape.in_port(0).get_source())

    dim_nodes = []
    # bounded by len(dim) steps; resolve() advances dim_idx itself
    for _ in range(len(mx_reshape.dim)):
        if dim_idx < len(mx_reshape.dim):
            in_idx, dim_idx, dim_nodes = self.resolve(in_idx, dim_idx, mx_reshape.dim, input_shape, dim_nodes)

    pattern = Concat(input_shape.graph,
                     dict(name=input_shape.id + '/ConcatMXReshape_', axis=0,
                          in_ports_count=len(dim_nodes))).create_node()
    for port_idx, dim_node in enumerate(dim_nodes):
        pattern.in_port(port_idx).connect(dim_node.out_port(0))

    reshape = Reshape(graph, dict(name=mx_reshape.id + '/Reshape_')).create_node()
    reshape.in_port(1).connect(pattern.out_port(0))
    mx_reshape.in_port(0).get_connection().set_destination(reshape.in_port(0))
    mx_reshape.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Wrap a convolution with a ``patch_stride`` attribute into two Reshape nodes.

    Before the convolution the input is reshaped to
    [in_shape[0], in_shape[1] / patch_stride, 1, patch_stride]. The shape arithmetic is done in the
    floating-point type given by ``graph.graph['cmd_params'].data_type`` and the final pattern is
    cast to int64. After the convolution the output is reshaped with pattern [0, -1].
    """
    conv = match['conv']
    conv_name = conv.soft_get('name', conv.id)

    # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
    in_shape = Shape(graph, {'name': conv_name + '/Shape'}).create_node()
    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
    shape_as_float = Cast(graph, {'name': conv_name + '/to_float', 'dst_type': float_type}).create_node()
    in_shape.in_port(0).connect(conv.in_port(0).get_source())
    shape_as_float.in_port(0).connect(in_shape.out_port(0))

    batch = node_to_get_shape_value_of_indices(shape_as_float, [0])
    height = node_to_get_shape_value_of_indices(shape_as_float, [1])

    height_div = create_op_with_const_inputs(graph, Div, {1: float_array([conv.patch_stride])},
                                             {'name': conv_name + '/div_stride_h'})
    height_div.in_port(0).connect(height.out_port(0))

    all_dims = create_op_with_const_inputs(
        graph, Concat,
        {2: float_array([1]), 3: float_array([conv.patch_stride])},
        {'name': conv_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    all_dims.in_port(0).connect(batch.out_port(0))
    all_dims.in_port(1).connect(height_div.out_port(0))

    int_pattern = Cast(graph, {'name': conv_name + '/to_int', 'dst_type': np.int64}).create_node()
    all_dims.out_port(0).connect(int_pattern.in_port(0))

    reshape_in = Reshape(graph, {'name': conv_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(int_pattern.out_port(0))

    # Reshape after the convolution
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': conv_name + '/reshape_out'})

    # input side: producer -> reshape_in -> conv
    producer = conv.in_port(0).get_source()
    conv.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(producer)

    # output side: conv -> reshape_out -> former consumers
    conv.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    conv.out_port(0).connect(reshape_out.in_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Normalise the image-info input (port 2) of a Proposal node.

    * A [1, 6] im_info logs a warning and is cut by a StridedSlice (begin [0, 0], end [1, 3],
      stride [1, 1]); it is then treated as [1, 3].
    * A [1, 3] or [1, 4] im_info gets a Reshape to [im_info_shape[1]] inserted before port 2.

    If output port 1 exists and is connected, the node's ``version`` is set to 'extension'.
    """
    proposal = match['proposal']
    assert len(proposal.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
    im_info_shape = proposal.in_port(2).data.get_shape()
    assert im_info_shape is not None

    if np.array_equal(im_info_shape, [1, 6]):
        log.error('The model contains Proposal layer "{}" with input of shape [1, 6]. Inference Engine '
                  'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
                  'Elements with indices 4 and 5 will be ignored.'.format(proposal.soft_get('name', proposal.id)),
                  extra={'is_warning': True})

        # begin, end and stride constants, in that order
        slice_params = [Const(graph, {'value': np.array(values, dtype=np.int32)}).create_node()
                        for values in ([0, 0], [1, 3], [1, 1])]

        cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
                                               'begin_mask': int64_array([1, 1]),
                                               'end_mask': int64_array([1, 1]),
                                               'new_axis_mask': int64_array([0]),
                                               'shrink_axis_mask': int64_array([0]),
                                               'ellipsis_mask': int64_array([0]),
                                               'override_output_shape': True,
                                               }).create_node()
        proposal.in_port(2).get_connection().insert_node(cropped_im_info)
        for port_idx, param in enumerate(slice_params, start=1):
            param.out_port(0).connect(cropped_im_info.in_port(port_idx))

        # update the im_info_shape so the next 'if' statement become true
        im_info_shape = int64_array([1, 3])

    if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
        reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
        target_len = Const(graph, dict(value=[im_info_shape[1]])).create_node()
        proposal.in_port(2).get_connection().set_destination(reshape.in_port(0))
        target_len.out_port(0).connect(reshape.in_port(1))
        reshape.out_port(0).connect(proposal.in_port(2))

    if proposal.has_port('out', 1) and not proposal.out_port(1).disconnected():
        # This is the case when Proposal layer is used from extension, not from opset.
        # Setting version attribute is not recommended, this will be fixed after Proposal will be updated in IE.
        graph.node[proposal.id]['version'] = 'extension'
def decompose_shuffle_channel(node: Node):
    """
    Replace a ShuffleChannel node with Reshape -> Transpose -> Reshape.

    The input is reshaped to [batch, group, channels / group, -1], axes 1 and 2 are swapped with
    Transpose(0, 2, 1, 3), and the result is reshaped back to the original input shape. The final
    Reshape takes the original node name; the original node is renamed with '/to_be_removed'.
    """
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    in_shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
    in_shape.in_port(0).connect(node.in_port(0).get_source())

    # split shape: [input_batch, group, input_channels / group, -1]
    batch = node_to_get_batch_value(in_shape)
    group = Const(graph, dict(name=name + '/Rows', value=int64_array([node.group]))).create_node()
    rest = Const(graph, dict(name=name + '/Const', value=int64_array([-1]))).create_node()

    channels = node_to_get_features_dimension_value(in_shape)
    per_group = create_op_node_with_second_input(graph, Div, np.int64(node.group),
                                                 {'name': name + '/Cols'}, input_node=channels)
    # Div output is cast back to int64 before it becomes part of a shape
    per_group_int = Cast(graph, {'name': per_group.name + '/Convert', 'dst_type': np.int64}).create_node()
    per_group.out_port(0).connect(per_group_int.in_port(0))

    split_shape = new_shape_node_from_shape_nodes([batch, group, per_group_int, rest])
    reshape_split = Reshape(graph, dict(name=name + '/Reshape_split_')).create_node()
    split_shape.out_port(0).connect(reshape_split.in_port(1))

    # swap the group and per-group axes: Transpose(0, 2, 1, 3)
    transpose = create_op_node_with_second_input(graph, Transpose, int64_array([0, 2, 1, 3]),
                                                 {'name': name + '/Transpose_'}, input_node=reshape_split)

    # Reshape back to input shape
    reshape_back = Reshape(graph, dict(name=name)).create_node()
    rename_node(reshape_back, name)
    in_shape.out_port(0).connect(reshape_back.in_port(1))
    transpose.out_port(0).connect(reshape_back.in_port(0))

    # re-wire the original node's input and output onto the new chain
    node.in_port(0).get_connection().set_destination(reshape_split.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_back.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Wrap a Kaldi pooling node into two Reshape nodes.

    If ``pool_step`` is None, ``stride`` is set to [1, 1, window[-1], window[-1]].
    Before pooling the input is reshaped to
    [in_shape[0], pool_stride, 1, in_shape[1] / pool_stride]; after pooling the output is
    reshaped with pattern [0, -1].
    """
    node = match['pool']

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # Reshape before pooling:
    # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
    shape = Shape(graph, {}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    # out 0 -> [in_shape[0]], out 1 -> remaining dims
    split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
                                        {'out_ports_count': 2}, shape)

    node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()

    # in_shape[1] / pool_stride. This used to be in_shape[1] * Pow(pool_stride, -1) on int64
    # values. An integer raised to a negative power is not representable (it truncates to 0 for
    # pool_stride > 1), so divide directly and cast back to int64 in case Div yields a
    # non-integral type.
    div = Div(graph, {}).create_node()
    div.in_port(0).connect(split.out_port(1))
    div.in_port(1).connect(node_pool_stride.out_port(0))
    div_int = Cast(graph, {'dst_type': np.int64}).create_node()
    div.out_port(0).connect(div_int.in_port(0))

    const_1 = Const(graph, {'value': int64_array([1])}).create_node()
    concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
    concat.in_port(0).connect(split.out_port(0))
    concat.in_port(3).connect(div_int.out_port(0))
    concat.in_port(2).connect(const_1.out_port(0))
    concat.in_port(1).connect(node_pool_stride.out_port(0))

    reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # Reshape after pooling
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node.name + '/Reshape/'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)

    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Replace the matched ``reshape`` op with a Reshape to its statically known output shape.

    The first element of the shape pattern is set to 0 (presumably "copy the batch dimension from
    the input" — confirm against the Reshape special_zero semantics in use), and shape input 1 is
    forced to int64 precision.
    """
    flatten = match['reshape']

    target_shape = np.copy(flatten.out_port(0).data.get_shape())
    target_shape[0] = 0

    reshape = Reshape(graph, dict(name=flatten.id)).create_node()
    shape_const = Const(graph, dict(name=flatten.id + '/DimData', value=target_shape)).create_node()

    flatten.in_port(0).get_connection().set_destination(reshape.in_port(0))
    shape_const.out_port(0).connect(reshape.in_port(1))
    flatten.out_port(0).get_connection().set_source(reshape.out_port(0))

    reshape['force_precision_in_ports'] = {1: 'int64'}
def replace_pattern(graph: Graph, match: dict):
    """
    Flatten the data input (port 0) of the matched op to 2-D when its rank exceeds 2.

    A Reshape with constant pattern [0, -1] is inserted before port 0, and shape inference is run
    immediately on the new Const and Reshape nodes.
    """
    node = match['op']
    if len(node.in_port(0).data.get_shape()) <= 2:
        return

    pattern = Const(graph, {'value': np.array([0, -1], dtype=np.int64)}).create_node()
    reshape = Reshape(graph, {}).create_node()

    producer = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape.out_port(0))
    producer.connect(reshape.in_port(0))
    pattern.out_port(0).connect(reshape.in_port(1))

    pattern.infer(pattern)
    reshape.infer(reshape)
def replace_pattern(self, graph: Graph, match: dict):
    """
    Turn a convolution with a 3-D output into a 2-D one by inserting a unit dimension at index 2.

    The convolution's dilation, stride and pad get an extra entry at index 2 (1, 1 and [0, 0]),
    kernel_spatial gets a leading 1, and the weights (input 1) get a unit dimension at index 2.
    The data input is reshaped to its shape with 1 inserted at index 2, and the output is reshaped
    back to the original 3-D output shape. Shape inference is re-run manually on the new input
    Const/Reshape and on the convolution.
    """
    conv = match['conv']

    assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
    out_data = conv.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
    out_shape = out_data.shape

    if out_shape.size != 3:
        return

    assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have at least 1 input data node".format(
        conv.id)
    inp_data = conv.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
    inp_shape = inp_data.shape
    new_inp_shape = np.insert(inp_shape, 2, 1)

    # setting to None to be overwritten by infer function
    conv.kernel_spatial_idx = None
    conv.spatial_dims = None

    # inserting fake H dimension
    conv.dilation = np.insert(conv.dilation, 2, 1)
    conv.kernel_spatial = np.append([1], conv.kernel_spatial)
    conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
    conv.stride = np.insert(conv.stride, 2, 1)

    weights_node = conv.in_node(1)
    weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
    weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

    reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    conv.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
    # Named after reshape_back. It previously reused `reshape.id + '/Dim'`, so both shape
    # constants ended up with the same name.
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape_back.id + '/Dim'}).create_node()
    conv.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    conv.infer(conv)
def replace_pattern(graph: Graph, match: dict):
    """
    Insert a Reshape to [im_info_shape[1]] before the image-info input (port 2) of a Proposal
    node whose im_info shape is [1, 3] or [1, 4]; other shapes are left untouched.
    """
    proposal = match['proposal']
    assert len(proposal.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
    im_info_shape = proposal.in_port(2).data.get_shape()
    assert im_info_shape is not None

    if not any(np.array_equal(im_info_shape, expected) for expected in ([1, 3], [1, 4])):
        return

    reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
    target_len = Const(graph, dict(value=[im_info_shape[1]])).create_node()
    proposal.in_port(2).get_connection().set_destination(reshape.in_port(0))
    target_len.out_port(0).connect(reshape.in_port(1))
    reshape.out_port(0).connect(proposal.in_port(2))
def find_and_replace_pattern(self, graph: Graph):
    """
    Give every Roll node that has no ``axes`` input (port 2) an explicit one.

    For each such node the data input is reshaped to 1-D ([-1]), a constant axes input [0] is
    connected to port 2, and the Roll output is reshaped back to the shape of the original data
    input. The final Reshape takes over the Roll node's name; the Roll is renamed with '/roll'.
    Roll nodes that already have an axes input are skipped.
    """
    for roll_node in graph.get_op_nodes(op='Roll'):
        if not roll_node.in_port(2).disconnected():
            # Axes already given: skip only this node. The previous `return` stopped processing
            # every Roll node that came after it.
            continue
        node_name = roll_node.soft_get('name', roll_node.id)

        # reshape to 1d tensor
        reshape_to_1d = create_op_node_with_second_input(graph, Reshape, int64_array([-1]),
                                                         {'name': node_name + '/reshape'})
        roll_node.in_port(0).get_connection().insert_node(reshape_to_1d)

        # add zero const as axes input to roll
        const_zero = Const(graph, {'value': int64_array([0]), 'name': node_name + '/axes'}).create_node()
        const_zero.out_port(0).connect(roll_node.in_port(2))

        # Reshape back to the original shape. ShapeOf must read the tensor that feeds the 1-D
        # Reshape; attaching it to the Roll's input would measure the already-flattened tensor.
        shape_of = Shape(graph, {'name': node_name + '/shape_of'}).create_node()
        reshape_to_1d.in_port(0).get_connection().add_destination(shape_of.in_port(0))
        reshape_to_orig_shape = Reshape(graph, {}).create_node()
        rename_nodes([(roll_node, node_name + '/roll'), (reshape_to_orig_shape, node_name)])
        shape_of.out_port(0).connect(reshape_to_orig_shape.in_port(1))
        roll_node.out_port(0).get_connection().insert_node(reshape_to_orig_shape)
def replace_pattern(graph: Graph, match: dict):
    """
    Reshape the condition of a TF-format Select when it is 1-D and the branches have rank > 1.

    The condition (input 0) is reshaped with pattern [0, 1, ..., 1] (length = branch rank) and
    that Reshape becomes the new source of Select's input 0. Both branches are asserted to have
    equal shapes.
    """
    select = match['op']
    if not (select.has_valid('format') and select['format'] == 'tf'):
        return

    condition = select.in_node(0)
    then_data = select.in_node(1)
    else_data = select.in_node(2)
    assert np.array_equal(then_data.shape, else_data.shape)

    if len(condition.shape) != 1 or len(then_data.shape) <= 1:
        return

    branch_rank = len(then_data.shape)
    pattern = np.array([0] + [1] * (branch_rank - 1), dtype=np.int64)
    pattern_const = Const(graph, {'name': select.name + '/Reshape/Dim/', 'value': pattern}).create_node()
    condition_reshape = Reshape(graph, dict(name=select.name + '/Broadcast/')).create_node(inputs=[condition])
    pattern_const.out_port(0).get_connection().set_destination(condition_reshape.in_port(1))

    select.in_port(0).disconnect()
    select.in_port(0).get_connection().set_source(condition_reshape.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace the matched normalization sub-graph with
    Reshape -> Transpose -> MVN -> Transpose -> Reshape -> Mul(gamma) -> Add(beta).

    Does nothing unless ``check_applicability(match)`` is true.

    * c1 and c2 are elements 1 and 2 of the constant shape input of the matched ``reshape``.
    * The tensor that fed the matched ``transpose`` is reshaped to [0, 0, 0, c1, c2].
    * Transpose(0, 1, 2, 4, 3) is applied.
    * MVN follows, with axes input [1, 2, 3], eps from the matched ``add`` and 'inside_sqrt' mode.
    * The same transpose is applied again.
    * The result is reshaped to the shape of the tensor that fed ``transpose``.
    * The gamma/beta constants are reshaped to [1, 1, 1, c1 * c2] and applied with Mul and Add.
    * Consumers of the matched ``transpose2`` are re-sourced from the new Add.
    """
    if not check_applicability(match):
        return

    reshape = match['reshape']
    div_name = match['division'].name

    # ShapeOf of the pattern's input; used as the target shape of the second Reshape
    input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()

    # the matched Reshape has a constant shape input; its dims 1 and 2 are used below
    shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
    c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
    c = c1 * c2

    new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                   dict(name=div_name + '/first_reshape/MVN_T_'))
    # swaps the last two axes; the same order is used again after MVN to undo the swap
    permute_order = int64_array([0, 1, 2, 4, 3])
    first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                     dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

    add = match['add']
    variance = match['variance']
    # eps is whichever input of `add` is not produced by `variance`
    eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
    eps = add.in_port(eps_port_num).get_connection().get_source().node
    mvn_node = create_op_with_const_inputs(graph, MVN, {1: int64_array([1, 2, 3])},
                                           dict(name=div_name + '/MVN/MVN_T_',
                                                eps=eps.value, normalize_variance=1,
                                                eps_mode='inside_sqrt'))
    first_permute.out_port(0).connect(mvn_node.in_port(0))

    second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                      dict(name=div_name + '/second_permute/MVN_T_'),
                                                      mvn_node)
    new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
    second_permute.out_port(0).connect(new_reshape2.in_port(0))

    # gamma/beta reshaped to [1, 1, 1, c1 * c2] — presumably to broadcast over a last axis of that
    # size; NOTE(review): confirm the restored tensor's last dim is c1 * c2
    gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                           int64_array([1, 1, 1, c]))
    new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                               dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)

    beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                          int64_array([1, 1, 1, c]))
    new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                dict(name=match['add2'].name + '/MVN_T_'), new_mul)

    # the producer of `transpose` now feeds the first new Reshape
    transpose_connection = match['transpose'].in_port(0).get_connection()
    before_transpose = transpose_connection.get_source().node
    transpose_connection.set_destination(new_reshape.in_port(0))
    # its shape becomes the target of the second Reshape
    # NOTE(review): uses out_port(0) of the producer — assumes `transpose` was fed from port 0
    input_shape.out_port(0).connect(new_reshape2.in_port(1))
    before_transpose.out_port(0).connect(input_shape.in_port(0))

    # consumers of `transpose2` now read from the new Add
    match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
def convert_ifft_to_dft(self, graph: Graph, mx_fft: Node):
    """
    Replace an MXNet inverse-FFT node with Reshape -> IDFT -> Split -> Squeeze.

    The input is reshaped with pattern [0] * (rank - 1) + [-1, 2], where the rank is computed at
    run time. IDFT uses axes [-1]. Its result is split into 2 along the last axis, and the Squeeze
    (axis -1) applied to the Split takes over the original node name.
    """
    fft_name = mx_fft.soft_get('name', mx_fft.id)

    rank = Rank(graph, {'name': fft_name + '/rank'}).create_node()
    rank_minus_one = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)}, {'name': fft_name + '/Sub'})
    rank.out_port(0).connect(rank_minus_one.in_port(0))

    # (rank - 1) zeros: keep all leading dimensions of the input
    leading_zeros = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                {'name': fft_name + '/broadcast'})
    rank_minus_one.out_port(0).connect(leading_zeros.in_port(1))

    new_shape = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                            {'name': fft_name + '/new_shape', 'in_ports_count': 2, 'axis': 0},
                                            leading_zeros)
    reshape = Reshape(graph, {'name': fft_name + '/reshape'}).create_node()
    new_shape.out_port(0).connect(reshape.in_port(1))

    input_connection = mx_fft.in_port(0).get_connection()
    input_connection.set_destination(reshape.in_port(0))
    input_connection.get_source().connect(rank.in_port(0))

    idft = create_op_with_const_inputs(graph, IDFT, {1: int64_array([-1])},
                                       {'name': fft_name + '/idft', 'in_ports_count': 2}, reshape)
    split = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)},
                                        {'name': fft_name + '/split', 'num_splits': 2}, idft)
    squeeze = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([-1])}, {}, split)

    mx_fft.out_port(0).get_connection().set_source(squeeze.out_port(0))
    rename_nodes([(mx_fft, fft_name + '/to_be_removed'), (squeeze, fft_name)])
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace a GatherNd node by Reshape + Gather(axis=0) when its constant indices allow it.

    ``self.indices_check`` returns an index into the last axis of ``indices``, or None when the
    replacement is not possible (a warning is logged in that case). The data is reshaped to
    [-1] + input_shape[indices.shape[-1]:]. The indices are reduced to the selected component and
    flattened to 1-D. The old GatherNd node is removed from the graph.
    """
    gather_nd = match['GatherNd']
    input_shape = gather_nd.in_node(0).shape
    indices = gather_nd.in_node(1).value
    if indices is None:
        # We can't do such special pass without indices value
        return

    # 0. All needed checks that we can replace GatherNd by Gather
    gather_idx = self.indices_check(indices, input_shape)
    if gather_idx is None:
        log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_nd.name))
        return

    # 1. Add Reshape and connect
    target_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
    reshape = Reshape(graph, {'name': gather_nd.name + '/Reshape_for_GatherNd/'}).create_node()
    reshape_dim = Const(graph, {'name': reshape.name + '/Dim', 'value': target_shape}).create_node()
    gather_nd.in_port(0).get_connection().set_destination(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    # 2. Change indices from Nd to 1d
    flat_indices = np.take(indices, indices=[gather_idx], axis=-1).reshape(-1)
    indices_const = Const(graph, dict(value=flat_indices)).create_node()
    axis_const = Const(graph, {'value': int64_array(0)}).create_node()

    # 3. Create new Gather operation and reconnect all inputs/outputs
    gather = Gather(graph, {'name': gather_nd.name + '/NewGather/'}).create_node()
    for port_idx, producer in enumerate((reshape, indices_const, axis_const)):
        producer.out_port(0).connect(gather.in_port(port_idx))
    gather_nd.out_port(0).get_connection().set_source(gather.out_port(0))

    # 4. Remove old GatherNd node
    graph.remove_node(gather_nd.id)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Surround the matched convolution with two Reshape nodes (``axis=1``, ``Reshape.kaldi_infer``):
    one between its data producer and input 0, one between its output 0 and the former consumers.
    """
    conv = match['conv']

    def make_kaldi_reshape(reshape_name):
        # both reshapes share the same Kaldi-specific shape inference
        return Reshape(graph, {'name': reshape_name, 'axis': 1, 'infer': Reshape.kaldi_infer}).create_node()

    reshape_in = make_kaldi_reshape('/Reshape/' + conv.name)
    reshape_out = make_kaldi_reshape(conv.name + '/Reshape/')

    # input side: producer -> reshape_in -> conv
    producer = conv.in_port(0).get_source()
    conv.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(producer)

    # output side: conv -> reshape_out -> former consumers
    conv.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    conv.out_port(0).connect(reshape_out.in_port(0))
def broadcast_with_reshape(port):
    """
    Insert a Reshape before ``port`` so its data can broadcast against ``input_port``.

    The Reshape's ``dim`` has the rank of ``input_port``'s shape: 1 everywhere, except that
    positions ``node.axis`` .. ``node.axis + len(port_shape) - 1`` hold the port's own shape.
    ``input_port``, ``node`` and ``graph`` are taken from the enclosing scope.
    """
    rank = len(input_port.data.get_shape())
    new_dims = np.zeros(rank, dtype=np.int64)

    start = node.axis
    port_shape = port.data.get_shape()
    stop = start + len(port_shape)

    for idx in range(start):
        new_dims[idx] = 1
    for idx in range(start, stop):
        new_dims[idx] = port_shape[idx - start]
    for idx in range(stop, rank):
        new_dims[idx] = 1

    reshape = Reshape(graph, dict(name=port.node.name + "/Broadcast_", dim=new_dims)).create_node()
    port.get_connection().set_destination(reshape.in_port(0))
    reshape.out_port(0).connect(port)
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Expand an MXReshape with ``reverse=True`` into ordinary operations.

    1. The input is reshaped to its own shape reversed
       (Shape -> Unsqueeze(0) -> Reverse(axis=1) -> Squeeze(0)).
    2. The reversed ``dim`` is applied. This uses a plain Reshape, or a new MXReshape when ``dim``
       contains any of the codes -2, -3, -4.
    3. The result is reshaped to its own shape reversed, using the same Shape/Reverse chain, and
       takes over the original node's output.
    """
    mxreshape = match['op']
    if not mxreshape.reverse:
        return

    # 1. reshape the input to its reversed shape
    shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
    forward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
    forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()
    forward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
    reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()
    shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
    mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
    forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
    forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
    reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

    # 2. apply the reversed `dim`. Choose the op up front: previously a plain Reshape (with its
    # Const input) was always created and then silently abandoned when MXReshape was needed.
    reversed_dim = int64_array(np.flip(mxreshape.dim, 0))
    if np.any(np.isin([-2, -3, -4], mxreshape.dim)):
        # special MXNet codes still need MXReshape semantics
        reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape', dim=reversed_dim)).create_node()
    else:
        reshape_shape_node = create_op_node_with_second_input(graph, Reshape, reversed_dim,
                                                              dict(name=str(mxreshape.id) + '/ReshapeShape'))
    reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

    # 3. reshape the result to its reversed shape
    backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
    backward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
    backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
    backward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
    backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

    backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
    backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
    backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))
    backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

    mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    """
    Concatenate prior boxes with broadcast variances and reshape the result to [1, 2, -1].

    * The priors (output 0 of ``priors_scale_node``) are reshaped to [-1, 4].
    * ``variance`` is broadcast to concat(slice_of_priors_shape, [4]). The slice is taken from the
      priors' shape with StridedSlice begin [-2], end [-1], stride [1], masks [1].
    * Both are concatenated along axis 0 and reshaped to [1, 2, -1].

    :param priors_scale_node: node whose output port 0 holds the priors
    :param variance: variance values to broadcast
    :return: the final Reshape node (not yet connected to any consumer)
    """
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    priors_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(priors_shape.in_port(0))

    slice_begin = Const(graph, {'value': np.array([-2])}).create_node()
    slice_end = Const(graph, {'value': np.array([-1])}).create_node()
    slice_stride = Const(graph, {'value': np.array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim',
                                                 'begin_mask': np.array([1]),
                                                 'end_mask': np.array([1]),
                                                 'new_axis_mask': np.array([0]),
                                                 'shrink_axis_mask': np.array([0]),
                                                 'ellipsis_mask': np.array([0])}).create_node()
    # data, begin, end, stride -> ports 0..3
    for port_idx, producer in enumerate((priors_shape, slice_begin, slice_end, slice_stride)):
        producer.out_port(0).connect(shape_part_for_tiling.in_port(port_idx))

    box_size = Const(graph, {'value': np.array([4])}).create_node()
    tiling_shape = Concat(graph, {'name': name + '/shape_for_tiling', 'in_ports_count': 2,
                                  'axis': np.array(0)}).create_node()
    shape_part_for_tiling.out_port(0).connect(tiling_shape.in_port(0))
    box_size.out_port(0).connect(tiling_shape.in_port(1))

    variance_const = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
    tiled_variance = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance_const.out_port(0).connect(tiled_variance.in_port(0))
    tiling_shape.out_port(0).connect(tiled_variance.in_port(1))

    flat_priors_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    flat_priors = Reshape(graph, {'name': name + '/reshape'}).create_node()
    flat_priors.in_port(0).connect(priors_scale_node.out_port(0))
    flat_priors.in_port(1).connect(flat_priors_dim.out_port(0))

    priors_concat = Concat(graph, {'name': name + '/priors_concat', 'axis': np.array(0),
                                   'in_ports_count': 2}).create_node()
    flat_priors.out_port(0).connect(priors_concat.in_port(0))
    tiled_variance.out_port(0).connect(priors_concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    priors_concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))
    return output_node
def replace_pattern(graph: Graph, match: dict):
    """
    Wrap the matched convolution with Reshape nodes driven by its ``patch_stride`` attribute.

    Before the convolution, the input ``[N, C]`` is reshaped to ``[N, -1, 1, patch_stride]``; the ``-1`` is
    resolved by Reshape to ``C / patch_stride``. After the convolution, its output is reshaped with the
    pattern ``[0, -1]``.

    The inner dimension used to be computed as ``C * Pow(patch_stride, -1)`` on int64 values. An integer
    raised to a negative power is invalid (numpy raises; integer arithmetic gives 0 for stride > 1), and a
    float promotion would put a non-integer value into an int64 shape pattern. Letting Reshape infer the
    dimension avoids both problems.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['conv']`` is the convolution node
    """
    node = match['conv']
    node_name = node.soft_get('name', node.id)

    # create Reshape before convolution
    # shape = [in_shape[0], in_shape[1] / patch_stride, 1, patch_stride], expressed as [N, -1, 1, patch_stride]
    shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    batch = node_to_get_shape_value_of_indices(shape, int64_array([0]))

    concat = create_op_with_const_inputs(graph, Concat,
                                         {1: int64_array([-1]),
                                          2: int64_array([1]),
                                          3: int64_array([node.patch_stride])},
                                         {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(batch.out_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # create Reshape after Convolution
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node_name + '/reshape_out'})

    # connect input_reshape_node: only the convolution's input connection is moved; the Shape node
    # created above keeps reading the original producer.
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)

    # connect output_reshape_node: former consumers of the convolution now read reshape_out
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """
    Replace the matched ``merge`` node with a Pow/Mul/Cast/Concat/ReduceMean/Reshape sub-graph.

    Built sub-graph:
      * ``power = Pow(x, -1.0)``, where ``x`` is the former producer of ``merge`` input 0;
      * for every ``idx`` in ``[merge.significant, merge.to_significant]`` (inclusive), the branch
        ``Cast(10**idx * power, float32) * 10**-idx`` feeds Concat input ``ii`` (the 0-based branch index);
      * the Concat (axis 0) is averaged by ReduceMean over axis 0 with ``keep_dims=False``;
      * the result is reshaped to the constant pattern ``[1, 5]`` and takes over ``merge``'s output consumers.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['merge']`` is the node being replaced
    """
    merge = match['merge']
    # NOTE(review): the node is named "reciprocal" but also gets 'type': 'PNORM', which overrides Pow's type
    # attribute — confirm this is intentional.
    power = Pow(graph, {
        'name': merge.name + '/reciprocal_',
        'type': 'PNORM'
    }).create_node()
    const1 = Const(graph, {
        'value': -1.0,
        'name': merge.name + '/negate_const'
    }).create_node()
    # Move merge's input-0 connection onto the Pow node; the exponent is the -1.0 constant.
    merge.in_port(0).get_connection().set_destination(power.in_port(0))
    const1.out_port(0).connect(power.in_port(1))

    concat_node = Concat(
        graph, {
            'axis': 0,
            'name': merge.name + '/Concat_',
            'override_output_shape': True
        }).create_node()
    # Reduction axis for the ReduceMean created after the loop.
    const3 = Const(graph, {
        'name': merge.name + '/const_reduce',
        'value': 0
    }).create_node()

    for ii, idx in enumerate(
            range(merge.significant, merge.to_significant + 1, 1)):
        # Scale up by 10**idx ...
        const_node = Const(
            graph, {
                'value': float_array(math.pow(10.0, idx)),
                'name': merge.name + '/Const_' + ii.__str__()
            }).create_node()

        mul_node = Mul(graph, {
            'name': merge.name + '/Mul_' + ii.__str__()
        }).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        power.out_port(0).connect(
            mul_node.in_port(1))  # connect to the graph node
        # ... and scale back down by 10**-idx after the Cast.
        mul_node2 = Mul(graph, {
            'name': merge.name + '/Mul_Div_' + ii.__str__()
        }).create_node()

        const_node2 = Const(
            graph, {
                'value': float_array(math.pow(10.0, -1 * idx)),
                'name': merge.name + '/Const_Pow_' + ii.__str__()
            }).create_node()
        # NOTE(review): a float -> float32 Cast does not truncate or round; if this pair of Muls is meant to
        # round to `idx` digits, the Cast likely needs an integer dst_type (or a Round op) — confirm intent.
        # Also note this name uses `idx` while the sibling nodes use `ii`.
        cast_node = Cast(
            graph, {
                'name': merge.name + '/Cast_' + idx.__str__(),
                'dst_type': np.float32
            }).create_node()

        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        # Concat input ports are created on demand, one per branch.
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(
            mul_node2.out_port(0))

    # NOTE(review): named "reduced_sum" but it is a ReduceMean.
    reducesum_node = ReduceMean(
        graph, {
            'name': merge.id + '/_pnorm_reduced_sum',
            'keep_dims': False,
            'in_ports_count': 2,
            'need_shape_inference': None,
            'infer': reduce_infer
        }).create_node()

    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(
        concat_node.out_port(0))

    reshape = Reshape(graph, {
        'name': merge.name + '/Reshape_Node'
    }).create_node()
    # NOTE(review): the output pattern [1, 5] is hard-coded; presumably it matches a specific model — confirm.
    reshape_dim = Const(graph, {
        'value': np.array([1, 5]),
        'name': merge.id + '/Reshape_Dim'
    }).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    # The Reshape takes over all consumers of the original merge output.
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace a Flatten node with an equivalent Reshape.

    The Reshape target is ``shape[:axis] + [prod(shape[axis:end_axis + 1])] + shape[end_axis + 1:]``.
    When ``end_axis == -1`` and ``axis >= 0`` a constant pattern ``[0] * axis + [-1]`` is used instead of
    computing it from the runtime shape.

    The Reshape takes over the Flatten's name (falling back to its id when it has no name), and the Flatten
    is renamed with a ``/ShouldBeDeleted`` suffix.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['flatten']`` is the Flatten node
    :raises AssertionError: if ``axis`` or ``end_axis`` is missing
    """
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        # Static pattern: zeros presumably copy the leading input dims (Reshape special_zero), -1 folds the rest.
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        # shape[:axis], shape[axis:end_axis + 1], shape[end_axis + 1:]
        begin_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=axis, end=end_axis,
                                                     include_end=True)
        end_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=end_axis, end=-1,
                                                  include_begin=False, include_end=True)

        # Collapse the flattened range into a single dimension.
        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    # Reuse the name computed above so a nameless Flatten falls back to its id consistently.
    abandoned_name = name + '/ShouldBeDeleted'
    reshape_node = Reshape(graph, {}).create_node()
    # Keep node with the same name to avoid confuse with renaming
    rename_nodes([(node, abandoned_name), (reshape_node, name)])

    reshape_node.in_port(1).connect(dim.out_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def mxrepeat_decomposition(node: Node):
    """
    Decompose a repeat node into Unsqueeze -> Tile -> Reshape.

    For an input of rank ``r`` repeated ``node.repeats`` times along ``node.axis``:
      * Unsqueeze inserts a new axis at ``canonical_axis + 1``;
      * Tile uses the repeats array ``[1] * (axis + 1) + [repeats] + [1] * (r - axis - 1)``;
      * Reshape restores ``shape[:axis] + [shape[axis] * repeats] + shape[axis + 1:]``.

    The original node is renamed with a ``/to_be_removed`` suffix and the final Reshape takes its name.

    :param node: the repeat node; reads its ``axis`` and ``repeats`` attributes
    """
    graph = node.graph
    name = node.soft_get('name', node.id)
    rename_node(node, name + '/to_be_removed')

    # ---- Unsqueeze: new axis placed right after the (canonicalized) repeat axis ----
    rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(rank.in_port(0))
    canonical_axis = get_canonical_axis_index_node(rank, node.axis)

    new_axis = create_op_node_with_second_input(graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
                                                input_node=canonical_axis)
    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(new_axis.out_port(0))

    # ---- Tile repeats array ----
    # position:  0 .. axis         | axis + 1 | axis + 2 .. rank
    # value:     1 .. 1            | repeats  | 1 .. 1
    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()

    leading_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    leading_ones.in_port(0).connect(one.out_port(0))
    leading_ones.in_port(1).connect(new_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    trailing_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    trailing_count = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    trailing_count.in_port(0).connect(rank.out_port(0))
    trailing_count.in_port(1).connect(new_axis.out_port(0))
    trailing_ones.in_port(0).connect(one.out_port(0))
    trailing_ones.in_port(1).connect(trailing_count.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([leading_ones, repeats, trailing_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # ---- Reshape pattern ----
    # position:  0 .. axis - 1     | axis                       | axis + 1 .. rank - 1
    # value:     shape[i]          | shape[axis] * repeats      | shape[i]
    shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(shape.in_port(0))

    head_dims = get_shape_values_by_range_idxs(shape, rank, begin=0, end=node.axis,
                                               include_begin=True, include_end=False)

    axis_dim = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
                                           input_node=shape)
    axis_dim.in_port(1).connect(canonical_axis.out_port(0))

    repeated_dim = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dim.in_port(0).connect(axis_dim.out_port(0))
    repeated_dim.in_port(1).connect(repeats.out_port(0))

    tail_dims = get_shape_values_by_range_idxs(shape, rank, begin=node.axis, end=-1,
                                               include_begin=False, include_end=True)

    reshape_dims = new_shape_node_from_shape_nodes([head_dims, repeated_dim, tail_dims])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(reshape_dims.out_port(0))

    # ---- Rewire: producer -> Unsqueeze -> Tile -> Reshape -> former consumers ----
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Decompose a GroupNorm node into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

    The input ``[N, C, *spatial]`` is reshaped to ``[N * num_groups, -1, *spatial]``; Reshape resolves the
    ``-1`` to ``C / num_groups``. MVN (``across_channels=1``, ``normalize_variance=1``, ``eps`` from the node)
    normalizes it, the result is reshaped back to the original input shape, multiplied by gamma and shifted
    by beta. Gamma (input 1) and beta (input 2) must be constants; their values are reshaped in place from
    ``[C]`` to ``[1, C, 1, ...]`` so they broadcast against the input.

    ``C / num_groups`` used to be computed as ``C * (1.0 / num_groups)``: an int64 shape value multiplied by a
    float64 constant, which produces a non-integer Reshape pattern and mixes element types. Using ``-1``
    keeps the whole pattern in int64.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['op']`` is the GroupNorm node
    :raises AssertionError: if gamma or beta is not constant
    """
    group_norm_node = match['op']

    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
    initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # "features // group_size" is left for Reshape to infer (-1), which avoids float arithmetic on shapes
    c_div_g_node = Const(graph, {'value': int64_array([-1]),
                                 'name': group_norm_node.name + '/FeaturesPerGroup'}).create_node()

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one: [N * G, -1, *spatial] (all int64)
    new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                      initial_spatial_dims_node])

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'across_channels': 1,
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """
    Convert the matched MatMul node in place into either FullyConnected or Gemm.

    FullyConnected is chosen when the second input has a value (or its producer has
    ``stop_value_propagation`` set) and at most two of its dimensions differ from 1. Then:
      * weights are transposed to ``(B)OI`` when ``transpose_b`` is not set, reshaped to ``[-1, K]`` when not
        2D, checked to be ``[O, K]``, and marked with ``bin = 'weights'``;
      * the input is transposed when ``transpose_a`` is set, reshaped to ``[-1, K]`` when not 2D, and checked
        to be ``[prod(B) * I, K]``;
      * the node becomes FullyConnected with ``out-size = O``; a Reshape to ``[*B, I, O]`` restores the
        output when the original output is not 2D.
    Otherwise the second input is unsqueezed up to the output rank and the node becomes Gemm.

    Every inserted node is inferred immediately (``.infer(...)``) so that the later ``get_shape()`` checks
    see the updated shapes; the statement order therefore matters.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['matmul']`` is the MatMul node
    :raises AssertionError: if input/output shapes are unknown or the normalized shapes are unexpected
    """
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()

    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            # Swap the last two axes of the weights.
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            # Insert Transpose between the weights producer and MatMul input 1.
            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        # Re-read the (possibly transposed) weights shape; flatten leading dims to [-1, K].
        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))

            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            # Swap the last two axes of the data input.
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        # Flatten leading dims of the data input to [-1, K].
        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        # FullyConnected yields [prod(B) * I, O]; restore [*B, I, O] for consumers when needed.
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)

            const.infer(const)
            reshape.infer(reshape)

    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        # Prepend unit dims to the second input so its rank matches the output rank.
        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                          }).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()
            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """
    Decompose a GroupNorm node into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

    The input ``[N, C, *spatial]`` is reshaped to ``[N * num_groups, -1, *spatial]``; Reshape resolves the
    ``-1`` to ``C / num_groups``. MVN (``normalize_variance=1``, ``eps`` from the node,
    ``eps_mode='inside_sqrt'``) normalizes over axes ``[1, rank)``. The result is reshaped back to the
    original input shape, multiplied by gamma and shifted by beta. Gamma (input 1) and beta (input 2) must
    be constants; their values are reshaped in place from ``[C]`` to ``[1, C, 1, ...]``.

    ``C / num_groups`` used to be computed in floating point as ``C * (1.0 / num_groups)`` and then
    truncated by a Cast to int64. Reciprocal rounding can land just below the exact integer (e.g.
    ``49 * (1 / 49) == 0.9999999999999999``), especially with FP16 data types, yielding a wrong dimension.
    That path also multiplied a float batch value by an int64 constant. Using ``-1`` keeps the whole
    pattern in exact int64 arithmetic.

    :param graph: graph being transformed
    :param match: matched sub-graph; ``match['op']`` is the GroupNorm node
    :raises AssertionError: if gamma or beta is not constant
    """
    group_norm_node = match['op']

    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
    initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # "features // group_size" is left for Reshape to infer (-1), so no float arithmetic on shapes is needed
    c_div_g_node = Const(graph, {'value': int64_array([-1]),
                                 'name': group_norm_node.name + '/FeaturesPerGroup'}).create_node()

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one: [N * G, -1, *spatial] (all int64)
    new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                      initial_spatial_dims_node])

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes: Range(1, rank, 1) over the reshaped tensor, i.e. every axis except the first
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))