def replace_identityN(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    assert node.has_valid('data_types'), 'IdentityN {} has no `data_types` attribute'.format(name)
    dtypes = node.data_types

    for idx, port in node.in_ports().items():
        if not node.is_in_port_connected(idx) or not node.is_out_port_connected(idx):
            # ATTENTION section in the description above
            continue
        assert idx < len(dtypes), 'IdentityN {} has inconsistent `data_types` attribute {}'.format(name, dtypes)
        identity = Identity(graph, {'name': '{}/{}_port'.format(name, idx),
                                    'data_type': dtypes[idx]}).create_node()
        port.get_connection().set_destination(identity.in_port(0))
        node.out_port(idx).get_connection().set_source(identity.out_port(0))

    # ATTENTION section in the description above
    for in_port in node.in_ports().values():
        in_port.disconnect()
    for out_port in node.out_ports().values():
        out_port.disconnect()
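
# Standalone illustrative sketch, not part of the transformation above (names prefixed with `_example_` are
# hypothetical). It only demonstrates which IdentityN ports the loop above rewires: an index gets its own
# Identity node only when both its input and its output are connected; all other ports are simply dropped.
def _example_identityn_ports_to_keep(connected_in_ports, connected_out_ports):
    return sorted(set(connected_in_ports) & set(connected_out_ports))

# e.g. inputs 0, 1, 2 connected but only outputs 0 and 2 used: Identity nodes are created for ports 0 and 2
assert _example_identityn_ports_to_keep([0, 1, 2], [0, 2]) == [0, 2]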
def re_numerate_input_ports(loop_node: Node):
    """
    Update input port ids to be consecutive from 0 to num_input_ports - 1 and update the port_map values of the
    Loop node.

    :param loop_node: the Loop node
    :return: None
    """
    def re_number_input_port(loop_node: Node, old_port_id: int, new_port_id: int):
        loop_node.add_input_port(new_port_id, skip_if_exist=True)
        loop_node.in_port(old_port_id).get_connection().set_destination(loop_node.in_port(new_port_id))
        Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', old_port_id, new_port_id)

    if len(loop_node.in_ports()) > 0:
        max_port_id = sorted(loop_node.in_ports().keys())[-1]
        new_port_id = 0
        for port_id in range(max_port_id + 1):
            if loop_node.is_in_port_connected(port_id):
                if port_id != new_port_id:
                    re_number_input_port(loop_node, port_id, new_port_id)
                new_port_id += 1

        for port_idx_to_remove in reversed(range(new_port_id, max_port_id + 1)):
            if port_idx_to_remove in loop_node.in_ports().keys():
                loop_node.delete_input_port(port_idx_to_remove)
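
# Standalone illustrative sketch, not part of the transformation above (no Loop/graph API, hypothetical data).
# It shows the core renumbering idea: connected port ids are compacted into a consecutive 0..N-1 range while
# their relative order is preserved; ids left over past the new maximum are the ones that get deleted.
def _example_compact_port_ids(connected_port_ids):
    # map every old (possibly sparse) connected port id to its new consecutive id
    return {old_id: new_id for new_id, old_id in enumerate(sorted(connected_port_ids))}

# e.g. connected ports 0, 2 and 5 become 0, 1 and 2; ports 3 and 4 beyond the new maximum are removed
assert _example_compact_port_ids([0, 2, 5]) == {0: 0, 2: 1, 5: 2}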
def get_rnn_input_size(node: Node):
    node_name = node.soft_get('name', node.id)
    assert node.is_in_port_connected(1), 'weights input is not connected'

    if node.format == 'onnx':
        # ONNX weights on input 1 contain only the W part; R and B are connected separately
        # weights_shape = [num_directions, 4 * hidden_size, input_size]
        weights_size = node.in_port(1).data.get_shape()
        assert len(weights_size) == 3, 'incorrect weights rank for ONNX {} node {}'.format(node.op, node_name)
        input_size = weights_size[2]
        return input_size
    elif node.format == 'mxnet':
        multiplier = node.multiplier
        hidden_size = node.hidden_size
        num_layers = node.num_layers
        direction = 2 if node.has_num_directions else 1

        # for MXNet models we always get flattened weights which contain W, R and B
        weights_size = node.in_port(1).data.get_shape()
        assert len(weights_size) == 1, 'incorrect weights rank for MXNet {} node {}'.format(node.op, node_name)
        weights_size = weights_size[0]

        size = hidden_size * direction * multiplier
        other_layer_params_size = (hidden_size * direction + hidden_size + 2) * size
        first_layer_params_size = weights_size - (num_layers - 1) * other_layer_params_size
        # the lines above that compute first_layer_params_size were taken from MXNetSplitMultiLayers.py:79

        # input_size can be calculated from first_layer_params_size:
        # first_layer_params_size = (input_size + hidden_size + 2) * size,
        # hence input_size = first_layer_params_size / size - 2 - hidden_size
        input_size = first_layer_params_size / size - 2 - hidden_size
        return input_size
    elif node.format == 'tf':
        log.error('reverse infer for TensorFlow RNN operation {} is not implemented yet'.format(node_name),
                  extra={'is_warning': True})
    else:
        raise Error('Incorrect framework name')
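
# Standalone illustrative sketch, not part of the transformation above (the numbers below are hypothetical).
# It demonstrates that the MXNet branch recovers input_size purely from arithmetic on the flattened weights
# length by inverting the packing formula first_layer_params_size = (input_size + hidden_size + 2) * size.
def _example_mxnet_input_size(weights_len, hidden_size, num_layers, direction, multiplier):
    size = hidden_size * direction * multiplier
    other_layer_params_size = (hidden_size * direction + hidden_size + 2) * size
    first_layer_params_size = weights_len - (num_layers - 1) * other_layer_params_size
    return first_layer_params_size / size - 2 - hidden_size

# e.g. an LSTM (multiplier=4) with hidden_size=128, input_size=64, one layer, one direction packs
# (64 + 128 + 2) * 4 * 128 = 99328 values, and the formula recovers input_size = 64 from that length
assert _example_mxnet_input_size(99328, hidden_size=128, num_layers=1, direction=1, multiplier=4) == 64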
def extend_inputs(node: Node, num_insertions: int):
    graph = node.graph
    node_name = node.soft_get('name', node.id)

    for i, input_name in [(1, 'begin'), (2, 'end'), (3, 'strides')]:
        if i == 3 and not node.is_in_port_connected(3):
            continue  # no need to extend strides if they are not connected

        blank_values_arr = np.zeros(num_insertions) if input_name != 'strides' else np.ones(num_insertions)
        blank_values_node = Const(graph, {'name': node_name + '/extend_{}_const'.format(input_name),
                                          'value': int64_array(blank_values_arr)}).create_node()

        if node.in_port(i).get_source().node.soft_get('type') == 'Concat':
            # concat already exists
            concat = node.in_port(i).get_source().node
            # the output data node shape will change when shapes are re-inferred, so no consistency check is needed
            concat['override_output_shape'] = True

            last_in_port = max(concat.in_ports().keys())
            assert not concat.in_port(last_in_port).disconnected(), \
                'The last in_port of Concat node {} should be connected'.format(concat.soft_get('name', node.id))

            concat.add_input_port(last_in_port + 1)
            concat.in_port(last_in_port + 1).connect(blank_values_node.out_port(0))
        else:
            # have to create concat
            concat = Concat(graph, {'axis': 0, 'name': node_name + '/concat_{}'.format(input_name),
                                    'in_ports_count': 2}).create_node()
            node.in_port(i).get_connection().set_destination(concat.in_port(0))
            concat.in_port(1).connect(blank_values_node.out_port(0))
            concat.out_port(0).get_connection().set_destination(node.in_port(i))
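
# Standalone illustrative sketch, not part of the transformation above (hypothetical values, no graph API; mask
# handling is outside this sketch). It shows the data-level effect of the concatenation performed above:
# `num_insertions` blank elements are appended, zeros for begin/end and ones for strides.
import numpy as np

def _example_extend_slice_inputs(begin, end, strides, num_insertions):
    return (np.concatenate([begin, np.zeros(num_insertions, dtype=np.int64)]),
            np.concatenate([end, np.zeros(num_insertions, dtype=np.int64)]),
            np.concatenate([strides, np.ones(num_insertions, dtype=np.int64)]))

_b, _e, _s = _example_extend_slice_inputs(np.array([1]), np.array([5]), np.array([2]), num_insertions=2)
assert _b.tolist() == [1, 0, 0] and _e.tolist() == [5, 0, 0] and _s.tolist() == [2, 1, 1]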
def reverse_infer(node: Node):
    input_shape = node.in_port(0).data.get_shape()

    if input_shape is None and node.is_in_port_connected(2) and node.in_port(2).data.get_shape() is not None:
        shape = undefined_shape_of_rank(node.in_port(2).data.get_shape()[0])
        node.in_port(0).data.set_shape(shape)
def quantize_to_fakequantize(graph: Graph, quantize_node: Node, set_stop_value_propagation=False):
    node_name = quantize_node.soft_get('name', quantize_node.id)
    axis = quantize_node.soft_get('axis', None)
    scale_y_shape = quantize_node.in_port(1).data.get_shape()

    if quantize_node.is_in_port_connected(2):
        zerop = quantize_node.in_port(2).get_source().node
    else:
        zerop = Const(graph, {'value': mo_array(0, dtype=np.uint8), 'name': node_name + '/ZeroPoint'}).create_node()

    assert zerop.soft_get('type') == 'Const', 'only constant for zero_point is supported for QuantizeLinear'
    zero_point_type = zerop.value.dtype
    # data type affects the range of output values: [-128..127] or [0..255]
    if zero_point_type == np.int8:
        output_low_value = -128.0
        output_high_value = 127.0
    elif zero_point_type == np.uint8:
        output_low_value = 0.0
        output_high_value = 255.0
    else:
        raise Error('Not expected type {} for zero point value in node {}'.format(
            zero_point_type, zerop.soft_get('name')))

    fake_quantize = create_op_with_const_inputs(graph, FakeQuantize,
                                                {3: float_array(output_low_value),
                                                 4: float_array(output_high_value)},
                                                {'levels': 256, 'name': node_name + '/FakeQuantize'})
    if set_stop_value_propagation:
        fake_quantize['stop_compression'] = True
        fake_quantize['stop_value_propagation'] = True
    quantize_node.in_port(0).get_connection().set_destination(fake_quantize.in_port(0))

    # Calculate input_low value
    mul_low = create_op_with_const_inputs(graph, Mul, {1: float_array(output_low_value - zerop.value)},
                                          {'name': node_name + '/Mul/Low'})
    quantize_node.in_port(1).get_connection().set_destination(mul_low.in_port(0))
    mul_low.out_port(0).connect(fake_quantize.in_port(1))

    # Calculate input_high value
    mul_high = create_op_with_const_inputs(graph, Mul, {1: float_array(output_high_value - zerop.value)},
                                           {'name': node_name + '/Mul/High'})
    mul_low.in_port(0).get_connection().add_destination(mul_high.in_port(0))
    mul_high.out_port(0).connect(fake_quantize.in_port(2))

    cast = Cast(graph, {'dst_type': zero_point_type, 'name': node_name + '/Cast'}).create_node()
    fake_quantize.out_port(0).connect(cast.in_port(0))
    quantize_node.out_port(0).get_connection().set_source(cast.out_port(0))
    rename_nodes([(quantize_node, node_name + '/TBD'), (cast, node_name)])

    assert scale_y_shape is not None, \
        "{0} contains scale (input with port 1) with shape None".format(quantize_node.soft_get('name', quantize_node.id))
    if axis is not None and len(scale_y_shape) > 0 and scale_y_shape[0] > 1:
        input_shape = fake_quantize.in_port(0).data.get_shape()
        target_shape = np.ones(len(input_shape), np.int64)
        target_shape[axis] = input_shape[axis]

        mul_low_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                      {'name': node_name + '/Reshape/Mul/Low'})
        mul_high_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                       {'name': node_name + '/Reshape/Mul/High'})

        fake_quantize.in_port(1).get_connection().set_destination(mul_low_reshape.in_port(0))
        fake_quantize.in_port(2).get_connection().set_destination(mul_high_reshape.in_port(0))

        mul_low_reshape.out_port(0).connect(fake_quantize.in_port(1))
        mul_high_reshape.out_port(0).connect(fake_quantize.in_port(2))
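
# Standalone illustrative numeric check, not part of the transformation above (hypothetical scale/zero_point
# values, simplified FakeQuantize reference). The FakeQuantize inputs built above are
# input_low = scale * (output_low - zero_point) and input_high = scale * (output_high - zero_point); with those
# inputs the FakeQuantize formula reproduces the QuantizeLinear reference saturate(round(x / scale) + zero_point).
import numpy as np

def _example_fake_quantize(x, in_low, in_high, out_low, out_high, levels=256):
    x = np.clip(x, in_low, in_high)
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    return q / (levels - 1) * (out_high - out_low) + out_low

_scale, _zero_point = 0.1, 128        # hypothetical uint8 quantization parameters
_out_low, _out_high = 0.0, 255.0      # uint8 output range, as selected in the code above
_in_low = _scale * (_out_low - _zero_point)
_in_high = _scale * (_out_high - _zero_point)

_x = np.array([-1.0, -0.3, 0.0, 0.2, 0.7, 30.0])
_reference = np.clip(np.round(_x / _scale) + _zero_point, _out_low, _out_high)
assert np.allclose(_example_fake_quantize(_x, _in_low, _in_high, _out_low, _out_high), _reference)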
def type_infer(node: Node):
    assert node.is_in_port_connected(1), \
        'The second input is not connected for a node {}.'.format(node.soft_get('name', node.id))
    node.out_port(0).set_data_type(node.in_port(1).get_data_type())
def get_rnn_batch_size_and_seq_len(node: Node):
    """
    Gets batch_size and sequence_length from RNN constant inputs and output shapes retrieved during reverse_infer.

    :param node: the RNN operation node
    :return: list with two elements: [batch_size, seq_len]
    """
    node_name = node.soft_get('name', node.id)
    out_shape = node.out_port(0).data.get_shape()
    batch_size = dynamic_dimension
    seq_len = dynamic_dimension
    in_port_with_initial_states = 3  # the input with initial hidden state values is framework dependent

    if out_shape is not None:
        # note that the op is not in the opset state but in the state of the original framework
        if node.batch_dim == 1:
            seq_len = out_shape[0]

            if node.format == 'mxnet':
                # for MXNet out_shape = [seq_len, batch_size, hidden_size]
                assert len(out_shape) == 3, 'incorrect out_shape rank for node {}'.format(node_name)
                batch_size = out_shape[1]
                in_port_with_initial_states = 2
            elif node.format == 'onnx':
                # even for ONNX the extractor sets 'batch_dim': 1 (front/onnx/lstm_ext.py:26) despite the fact that
                # out_shape = [seq_len, num_directions, batch_size, hidden_size]
                assert len(out_shape) == 4, 'incorrect out_shape rank for node {}'.format(node_name)
                batch_size = out_shape[2]
                in_port_with_initial_states = 5
            elif node.format == 'tf':
                log.error('reverse infer for TensorFlow RNN operation {} is not implemented yet'.format(node_name),
                          extra={'is_warning': True})
            else:
                raise Error('Incorrect framework name')
        elif node.batch_dim == 0:
            # out_shape = [batch_size, num_directions, seq_len, hidden_size]
            batch_size = out_shape[0]
            seq_len = out_shape[2]
            in_port_with_initial_states = 3
        else:
            raise Error('incorrect batch_dim for node {}'.format(node_name))

    if batch_size is dynamic_dimension:
        if node.is_in_port_connected(in_port_with_initial_states):
            initial_hidden_state_size = node.in_port(in_port_with_initial_states).data.get_shape()
            if initial_hidden_state_size is not None:
                batch_size = initial_hidden_state_size[1]

    if seq_len is dynamic_dimension and node.format == 'onnx':
        # ONNX can store seq_len in an optional input
        if node.is_in_port_connected(4):
            seq_len_val = node.in_port(4).data.get_value()
            if seq_len_val is not None:
                seq_len = seq_len_val.item()

    return [batch_size, seq_len]
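
# Standalone illustrative sketch, not part of the function above (hypothetical shapes, no graph API). It shows
# how batch_size and seq_len are read from the RNN output shape depending on the originating framework layout.
def _example_rnn_dims(out_shape, fmt, batch_dim):
    if batch_dim == 1:
        if fmt == 'mxnet':   # [seq_len, batch_size, hidden_size]
            return out_shape[1], out_shape[0]
        if fmt == 'onnx':    # [seq_len, num_directions, batch_size, hidden_size]
            return out_shape[2], out_shape[0]
    elif batch_dim == 0:     # [batch_size, num_directions, seq_len, hidden_size]
        return out_shape[0], out_shape[2]
    raise ValueError('unsupported combination')

assert _example_rnn_dims([10, 1, 32, 128], 'onnx', batch_dim=1) == (32, 10)  # batch_size=32, seq_len=10
assert _example_rnn_dims([32, 2, 10, 128], 'onnx', batch_dim=0) == (32, 10)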