def add_sub_graph_call_output_tensors_transposes(node: Node):
    """ Adds transpose operations to the output nodes if they are 4D to change layout from NCHW to NHWC.

    NOTE(review): the permutation applied is `nhwc_to_nchw_permute`; confirm the
    docstring's stated direction against the callers.
    :param node: the node whose 4D output tensors get trailing transposes.
    :return: None
    """
    try:
        # TensorFlow 2 environment: fall back to the v1 compatibility API and
        # switch off eager execution immediately, since graph-mode ops are built below.
        import tensorflow.compat.v1 as tf_v1
        tf_v1.disable_eager_execution()
    except ImportError:
        import tensorflow as tf_v1
    from mo.front.tf.partial_infer.tf import get_subgraph_output_tensors, add_node_def_to_subgraph

    _, output_tensors = get_subgraph_output_tensors(node)

    # constant holding the layout permutation order fed to tf.transpose
    permute_const = tf_v1.constant(nhwc_to_nchw_permute, dtype=tf_v1.int32, name=nhwc_to_nchw_constant_name)
    # stand-in input for the transpose op; the real producer tensor is patched in below
    fake_input = tf_v1.constant(value=[[[[1]]]], dtype=tf_v1.float32, name='random_dummy_name')

    updated_tensor_names = list()
    for tensor_name in node['output_tensors_names']:
        producer_name, port = tensor_name.split(':')
        # TODO think about better check whether transpose is required
        if len(output_tensors[int(port)].shape) == 4:
            transpose_name = producer_name + '_port_' + port + '_transpose'
            transpose_op = tf_v1.transpose(fake_input, permute_const, name=transpose_name)
            # starting from TF 1.8 the "node_def" of a "tf.op" cannot be modified in place,
            # so work on a deep copy: swap the dummy input for the real output tensor name
            patched_inputs = transpose_op.op.node_def.input[:]
            patched_inputs[0] = tensor_name
            patched_node_def = copy.deepcopy(transpose_op.op.node_def)
            patched_node_def.input[:] = patched_inputs
            # append the transpose after all existing nodes of the subgraph
            add_node_def_to_subgraph(node, patched_node_def, position=len(node['nodes_order']))
            updated_tensor_names.append(transpose_name)
        else:
            # non-4D outputs are passed through untouched
            updated_tensor_names.append(tensor_name)
    # update output tensor names with transposes operations
    node['output_tensors_names'] = updated_tensor_names
def update_placeholder_shape_and_add_transpose(node: Node):
    """ The function changes placeholders shapes from NHWC to NCHW format and add transpose operations if needed.

    :param node: node to operate on.
    :return: None
    """
    # Use the TF 1.x compatibility API when running under TensorFlow 2, mirroring
    # add_sub_graph_call_output_tensors_transposes: tf.reset_default_graph() and the
    # graph-mode op construction below do not exist in native TF 2 eager mode.
    try:
        import tensorflow.compat.v1 as tf
        tf.disable_eager_execution()
    except ImportError:
        import tensorflow as tf
    from mo.front.common.layout import convert_shape, nhwc_to_nchw_permute, nchw_to_nhwc_permute
    from mo.front.tf.extractors.utils import tf_tensor_shape
    from mo.front.tf.partial_infer.tf import add_node_def_to_subgraph, update_input_in_pbs

    tf.reset_default_graph()

    inputs_replacements = list()

    # transpose permutation constants (shared by all placeholders)
    nchw_to_nhwc_constant = tf.constant(nchw_to_nhwc_permute, dtype=tf.int32, name=nchw_to_nhwc_constant_name)
    nhwc_to_nchw_constant = tf.constant(nhwc_to_nchw_permute, dtype=tf.int32, name=nhwc_to_nchw_constant_name)

    for placeholder_name in node['input_nodes_names']:
        # dummy node which we can refer to as input in the transpose for the output node;
        # it must be unique for each placeholder, hence the name suffix
        dummy_node = tf.constant(value=[[[[1]]]], dtype=tf.float32, name='random_dummy_name_' + placeholder_name)

        placeholder = node['pbs'][placeholder_name]
        cur_shape = tf_tensor_shape(placeholder.attr['shape'].shape)
        if len(cur_shape) == 4:  # TODO think about better check that transpose is required
            # rewrite the placeholder's declared shape in NCHW order, in place
            nchw_shape = convert_shape(cur_shape, nhwc_to_nchw_permute)
            for ind in range(len(cur_shape)):
                placeholder.attr['shape'].shape.dim[ind].size = nchw_shape[ind]
            transpose_name = placeholder.name + '_transpose'
            transpose = tf.transpose(dummy_node, nchw_to_nhwc_constant, transpose_name)  # NCHW -> NHWC

            # add transpose operations to GraphDef after placeholders
            add_node_def_to_subgraph(node, transpose.op.node_def, transpose_name, len(node['input_nodes_names']))
            # consumers of the placeholder are re-wired to the transpose output,
            # and the transpose's dummy input is re-wired to the placeholder
            inputs_replacements.append((placeholder.name, transpose_name))
            inputs_replacements.append((dummy_node.name, placeholder.name))
            node['real_input_dims'].append(nchw_shape)
        else:
            node['real_input_dims'].append(cur_shape)

    add_node_def_to_subgraph(node, nchw_to_nhwc_constant.op.node_def)
    add_node_def_to_subgraph(node, nhwc_to_nchw_constant.op.node_def)

    # update initial input names to a transposed ones
    for old_input_tensor_name, new_name in inputs_replacements:
        update_input_in_pbs(node, old_input_tensor_name, new_name)