def replace_pattern(graph: Graph, match: dict):
    """
    Reorder the shape input (port 2) of a layout-dependent op for an NCHW graph.

    The rewrite applies only when both of these hold:
      * port 2 exists, is connected and the op has the ``shape_input`` flag set;
      * the op's ``layout`` does not start with 'NC' while the graph layout is 'NCHW'.

    When it applies, a Gather (axis 0) is inserted between the port-2 data node and the
    op. The Gather indices are the NHWC->NCHW permutation for the rank of input 0. The
    original edge attributes are moved onto the new edge.
    """
    node = match['op']

    has_shape_input = node.has_port('in', 2) and not node.in_port(2).disconnected() \
        and node.has_and_set('shape_input')
    if not has_shape_input:
        return

    if not node.has_valid('layout') or node.layout.startswith('NC') or graph.graph['layout'] != 'NCHW':
        return

    rank = len(node.in_port(0).data.get_shape())
    nhwc_to_nchw = PermuteAttrs.get_nhwc_to_nchw_permutation(rank)

    shape_data = node.in_node(2)
    gather_name = node.soft_get('name', node.id) + '/ShapeGather'

    # Gather(shape, indices=permutation, axis=0) reorders the shape values themselves.
    indices = Const(graph, {'value': nhwc_to_nchw.perm,
                            'name': gather_name + '/Const',
                            'need_shape_inference': True}).create_node_with_data()
    axis = Const(graph, {'value': int64_array(0),
                         'name': gather_name + '/Axis'}).create_node_with_data()
    reordered = Gather(graph, {'name': gather_name,
                               'need_shape_inference': True}).create_node_with_data([shape_data, indices, axis])

    # Re-route port 2 through the Gather, keeping the original edge attributes.
    edge_attrs = graph.get_edge_data(shape_data.id, node.id, key=0).copy()
    graph.add_edge(reordered.id, node.id, **edge_attrs)
    graph.remove_edge(shape_data.id, node.id)
def infer(node: Node):
    """
    Infer StridedSlice shapes and, for NHWC graphs, set up NHWC->NCHW mask permutation.

    After ``tf_strided_slice_infer`` runs, nothing more is done unless the graph layout is
    'NHWC' and the output value is not constant-folded. In that case:
      * the five slice masks are registered for permutation along with input 0;
      * every constant input on ports 1.. whose first dimension exceeds 3 is rewritten via
        ``permute_array_with_ellipsis`` (presumably the begin/end/strides inputs);
      * all set bits of ``ellipsis_mask`` are cleared.
    """
    tf_strided_slice_infer(node)

    if node.graph.graph['layout'] != 'NHWC' or node.out_port(0).data.get_value() is not None:
        return

    mask_names = ('shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask', 'end_mask')
    PermuteAttrs.create_permute_attrs(node, attrs=[(mask, 'input:0', permute_masks) for mask in mask_names])

    for port in range(1, len(node.in_nodes())):
        in_data = node.in_node(port)
        # Only constant inputs that describe more than 3 dimensions are reordered.
        if in_data.value is None or in_data.shape[0] <= 3:
            continue
        perm = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.in_node(0).shape))
        in_data.value = permute_array_with_ellipsis(node, perm, in_data.value, 0)

    # due to permutation from nhwc to nchw we will extend all masks and inputs
    idx = np.nonzero(node.ellipsis_mask)
    node.ellipsis_mask[idx] = 0
def permute_array_with_ellipsis(node: Node, permutation: PermuteAttrs.Permutation, array: np.array, ins_value: int):
    """
    Permute a StridedSlice mask (or constant slice input) from NHWC to NCHW order.

    Several cases are processed:
      * some dimensions can be omitted from the mask because of the ellipsis mask;
      * the mask can be shorter than the output rank plus the number of shrunk dimensions,
        in which case it is first extended with ``ins_value``;
      * the mask can be as long as, or longer than, that rank. In that case only the
        leading ``rank`` entries are permuted and the tail is left in place.

    :param node: StridedSlice node; reads its input/output data shapes, ``shrink_axis_mask``
        and ``ellipsis_mask``.
    :param permutation: NOTE(review): not used by this function. The permutation is
        recomputed below from the mask length. Kept for interface compatibility.
    :param array: mask values to permute.
    :param ins_value: value inserted for dimensions the mask does not mention.
    :return: the mask unchanged as a ``list`` when both input and output ranks are below 4,
        otherwise a permuted ``np.ndarray``.
    """
    attr_mask_extended = list(array)

    # If input and output have length of shape 3 and less, no need to permute
    if len(node.in_node().shape) < 4 and len(node.out_node().shape) < 4:
        return attr_mask_extended

    # Number of dimensions the mask should describe: output dims plus dims removed by shrink_axis_mask.
    full_rank = len(node.out_node(0).shape) + np.count_nonzero(node.shrink_axis_mask)

    # Length of mask is less than length of output plus shrinked dimensions: extend it before permutation
    if len(attr_mask_extended) < full_rank:
        # If ellipsis is set, add dimensions right after it; otherwise append them at the end.
        if np.any(node.ellipsis_mask):
            idx = np.nonzero(node.ellipsis_mask)
            assert len(idx[0]) == 1
            insert_after = idx[0][0]
        else:
            insert_after = len(attr_mask_extended) - 1

        missing = full_rank - len(attr_mask_extended)
        # Slice assignment inserts all missing entries at once, exactly like repeated list.insert.
        attr_mask_extended[insert_after + 1:insert_after + 1] = [ins_value] * missing

        # permute extended mask
        perm = PermuteAttrs.get_nhwc_to_nchw_permutation(len(attr_mask_extended))
        return np.array(attr_mask_extended)[perm.perm]

    perm = PermuteAttrs.get_nhwc_to_nchw_permutation(full_rank)
    perm_list = list(perm.perm)
    # if mask length is more than output, just add tail that will not be permuted to avoid error
    perm_list.extend(range(full_rank, len(attr_mask_extended)))
    return np.array(attr_mask_extended, dtype=np.int64)[np.array(perm_list)]
def permute_array(node: Node, array: np.array):
    """
    Permute a StridedSlice mask from NHWC to NCHW order.

    The mask must be at least as long as the output rank plus the number of shrunk axes.
    Entries beyond that length keep their positions.

    :return: the mask as a plain ``list`` if both input and output ranks are below 4,
        otherwise an int64 array in permuted order.
    """
    mask = list(array)

    # Ranks of 3 and below are left untouched.
    if len(node.in_port(0).data.get_shape()) < 4 and len(node.out_port(0).data.get_shape()) < 4:
        return mask

    rank = len(node.out_port(0).data.get_shape()) + np.count_nonzero(node.shrink_axis_mask)
    # Permute the leading `rank` entries; the tail indices are appended in identity order.
    order = [*PermuteAttrs.get_nhwc_to_nchw_permutation(rank).perm, *range(rank, len(mask))]
    return int64_array(mask)[int64_array(order)]
def strided_slice(op_node: Node, port_info: str, input_port: int):
    """
    Reorder a StridedSlice slice-spec input (begin/end/strides) from NHWC to NCHW order.

    StridedSlice must be permuted even if its input or output tensor has rank less than 4, e.g.
        input_shape = (1, 10, 10), out = input[:, 0:10, :, new_axis], input_rank < 4
        input_shape = (1, 10, 10, 3), out = input[:, 0:5, 0:4, 0], output_rank < 4
    In both examples slice_rank is >= 4. slice_rank is the length of begin, end and
    strides, which all have the same length.

    Nothing is done when the permutation stored on the related data node is empty.
    """
    data_node = get_node_with_permutation(op_node, port_info)
    assert data_node.has_and_set('permutation'), 'Data node "{}" does not have permutation for node {}, ' \
                                                 'port_info "{}".'.format(data_node.id, op_node.id, port_info)

    if len(data_node.permutation.perm) == 0:
        return

    from mo.ops.op import PermuteAttrs

    # The permutation is rebuilt from the slice rank, not taken from the data node.
    slice_rank = op_node.in_port(input_port).data.get_shape()[0]  # length of begin, end or strides
    gather_indices = PermuteAttrs.get_nhwc_to_nchw_permutation(slice_rank).perm
    reorder_inputs_for_shape_or_slice(op_node, input_port, gather_indices)
def find_and_replace_pattern(self, graph: Graph):
    """
    Surround every op whose ``layout`` disagrees with the graph layout with Transposes.

    For each such op:
      1. a Transpose is inserted on input 0 to convert data into the op's layout;
      2. a Transpose is inserted on output 0 to convert the result back, reusing the
         original output data node;
      3. the op's attributes are permuted via ``node.permute_attrs``, using a temporary
         ``permutation`` attribute on its input/output data nodes that is reset afterwards.
    """
    for node_id in list(graph.nodes()):
        node = Node(graph, node_id)
        # Check that node layout mismatch with graph layout
        # For example: NHWC and NCHW or NCDHW and NDHWC
        if not (node.kind == 'op' and node.has_valid('layout') and
                node.layout != indices_mapping[len(node.layout)][graph.graph['layout']]):
            continue

        # soft_get keeps this working for ops without a 'name' attribute.
        node_name = node.soft_get('name', node.id)
        in_data = node.in_node()
        out_data = node.out_node()

        # Calculate permutation for further Transpose operations
        if graph.graph['layout'] == 'NCHW':
            # if Node has NHWC and graph has NCHW layout
            permutation = PermuteAttrs.get_nhwc_to_nchw_permutation(len(node.layout))
        else:
            # if Node has NCHW and graph has NHWC layout
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(node.layout))

        # Schematic representation of transformation below
        #
        #                                           \            NCHW                              NCHW
        #            NHWC                        --  \            |  permutation       permutation  |
        #   data-->Convolution(example)-->data   --  /            |      |       NCHW      |        |
        #                                           /   data->Transpose->data->Convolution->data->Transpose->data

        # 1. Insert input Transpose
        #    This Transpose will permute input from original input layout to operation layout
        edge_attrs = graph.get_edge_data(in_data.id, node.id)[0]
        graph.remove_edge(in_data.id, node.id)

        input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
        input_permute_op = Transpose(graph, dict(name=node_name + '/Transpose_'))
        input_permute_data_node = input_permute_op.create_node_with_data([in_data, input_order_const])

        graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)

        # 2. Insert output Transpose
        #    This Transpose will permute output from operation layout to original input layout
        edge_attrs = graph.get_edge_data(node.id, out_data.id)[0]
        graph.remove_edge(node.id, out_data.id)

        # New data node holding the op's result in the op's own layout.
        input_data_node = Op.create_data_node(graph, node, {'shape': out_data.shape[permutation.perm]},
                                              edge_attrs)

        output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
        Transpose(graph, dict(name=node_name + '/Transpose_')).create_node_with_data(
            [input_data_node, output_order_const], data_nodes=out_data)

        # 3. Add permutations for Node
        #    Data nodes temporarily carry the permutation attribute, then permute_attrs
        #    permutes the node attributes according to it.
        node.in_node()['permutation'] = permutation
        node.out_node()['permutation'] = permutation
        node.permute_attrs.permute_attrs(node)
        node.in_node()['permutation'] = None
        node.out_node()['permutation'] = None
def create_topK_net(shape, k, ir_version, use_new_frontend):
    """
    Build a TensorFlow TopK model and, when applicable, its reference IR graph.

        Tensorflow net:

                          |-> Values
            Input -> TopK |
                          |-> Indices

        IR net:

                          |-> Values
            Input -> TopK |
                          |-> Indices

    :param shape: NCHW-ordered input shape; the TF placeholder uses the
        ``permute_nchw_to_nhwc`` version of it.
    :param k: number of top elements to keep.
    :param ir_version: IR version passed to ``check_ir_version``.
    :param use_new_frontend: when True, no reference IR net is built.
    :return: ``(tf_net, ref_net)``. ``ref_net`` is None unless
        ``check_ir_version(10, None, ir_version)`` passes and the legacy frontend is used.
    """
    import tensorflow as tf

    tf.compat.v1.reset_default_graph()

    with tf.compat.v1.Session() as sess:
        placeholder = tf.compat.v1.placeholder(tf.int32, shape=permute_nchw_to_nhwc(shape), name='Input')
        tf.nn.top_k(placeholder, k=k, sorted=True, name='Operation')
        tf.compat.v1.global_variables_initializer()
        tf_net = sess.graph_def

    # TF TopK works on the last axis; map that axis to the IR (NCHW) layout.
    topk_output_shape = shape.copy()
    rank = len(topk_output_shape)
    inverse_nhwc_nchw = PermuteAttrs.get_nhwc_to_nchw_permutation(rank).inv
    topk_axis = permute_axis(rank - 1, inverse_nhwc_nchw)
    topk_output_shape[topk_axis] = k

    if not check_ir_version(10, None, ir_version) or use_new_frontend:
        return tf_net, None

    nodes_attributes = {
        'input': {'kind': 'op', 'type': 'Parameter'},
        'input_data': {'shape': shape, 'kind': 'data'},
        'Const_k_input_data': {'shape': [], 'kind': 'data'},
        'Const_k': {'kind': 'op', 'type': 'Const'},
        'Const_k_data': {'shape': [], 'kind': 'data'},
        'TopK': {'kind': 'op', 'type': 'TopK', 'axis': topk_axis, 'mode': 'max', 'sort': 'value'},
        'TopK_data_1': {'shape': topk_output_shape, 'kind': 'data'},
        'TopK_data_2': {'shape': topk_output_shape, 'kind': 'data'},
        'result_1': {'kind': 'op', 'type': 'Result'},
        'result_2': {'kind': 'op', 'type': 'Result'},
    }
    edges = [
        ('input', 'input_data'),
        ('input_data', 'TopK', {'in': 0}),
        ('Const_k_input_data', 'Const_k'),
        ('Const_k', 'Const_k_data'),
        ('Const_k_data', 'TopK', {'in': 1}),
        ('TopK', 'TopK_data_1', {'out': 0}),
        ('TopK', 'TopK_data_2', {'out': 1}),
        ('TopK_data_1', 'result_1'),
        ('TopK_data_2', 'result_2'),
    ]
    return tf_net, build_graph(nodes_attributes, edges)
def permute_nhwc_to_nchw(shape):
    """Return ``shape`` reordered from NHWC to NCHW as a new numpy array."""
    order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(shape)).perm
    return np.array(shape)[order]
def permute_nhwc_to_nchw(shape, use_new_frontend=False):
    """
    Reorder ``shape`` from NHWC to NCHW.

    :param shape: shape to reorder.
    :param use_new_frontend: when True, ``shape`` is returned unchanged (the same object).
    :return: a new numpy array in NCHW order, or ``shape`` itself for the new frontend.
    """
    if not use_new_frontend:
        order = PermuteAttrs.get_nhwc_to_nchw_permutation(len(shape)).perm
        return np.array(shape)[order]
    return shape