def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace a TFSlice ``(input, begin, size)`` node with a Slice ``(input, start, end)``.

    Ends are computed as ``begin + size``. Where ``size[i] == -1``, the end is
    replaced with the int32 maximum (presumably clamped by Slice to the dimension
    size, i.e. "slice to the end" -- confirm against Slice shape inference).
    The new Slice node takes over the original node's name. The original node is
    renamed with a ``/to_be_removed`` suffix, and its output consumers are moved to
    the Slice.

    :param graph: graph being transformed
    :param match: pattern match dict; ``match['op']`` is the TFSlice node
    """
    node = match['op']
    slice_name = node.soft_get('name', node.id)
    slice_node = Slice(graph).create_node()
    rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])

    eq_node = Equal(graph, {'name': slice_name + '/equal'}).create_node()
    minus_one_node = Const(graph, {'name': slice_name + '/minus_one', 'value': np.array(-1)}).create_node()
    # Select needs both data inputs to share one element type. Input 2 is cast to int64 below, so this
    # constant is pinned to int64 as well. A bare Python int would pick up NumPy's platform-default
    # integer type, which is int32 on Windows before NumPy 2.0.
    int32_max_node = Const(graph, {'name': slice_name + '/int32_max',
                                   'value': np.array(np.iinfo(np.int32).max, dtype=np.int64)}).create_node()
    select_node = Select(graph, {'name': slice_name + '/select'}).create_node()

    # node to convert sizes to ends: end = begin + size
    sum_node = Add(graph, {'name': slice_name + '/end_const'}).create_node()

    # reconnect the data input of TFSlice to input 0 of Slice
    node.in_port(0).get_source().connect(slice_node.in_port(0))
    node.in_port(0).disconnect()
    # reconnect begin of TFSlice to start (input 1) of Slice
    node.in_port(1).get_source().connect(slice_node.in_port(1))
    node.in_port(1).disconnect()

    # (size -> ends): feed begin (now the source of Slice input 1) and size into the Add
    slice_node.in_port(1).get_source().connect(sum_node.in_port(0))
    node.in_port(2).get_source().connect(sum_node.in_port(1))
    node.in_port(2).disconnect()

    # where size[i] == -1, take int32_max as end[i]
    sum_node.in_port(1).get_source().connect(eq_node.in_port(0))
    minus_one_node.out_port(0).connect(eq_node.in_port(1))
    # condition: Equal -> Select input 0
    eq_node.out_port(0).connect(select_node.in_port(0))
    # value where the condition holds: int32_max -> Select input 1
    int32_max_node.out_port(0).connect(select_node.in_port(1))
    # value otherwise: begin + size -> Select input 2
    sum_node.out_port(0).connect(select_node.in_port(2))
    # Select output -> end (input 2) of Slice
    select_node.out_port(0).connect(slice_node.in_port(2))

    # cast begin + size to int64 on its way into Select input 2
    cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
    select_node.in_port(2).get_connection().insert_node(cast)

    node.out_port(0).get_connection().set_source(slice_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """Put a Select in front of the matched node so that no garbage is stored into it.

    The context length is 1 plus, for every Splice node between the graph
    Parameters and the matched node's input, ``len(context) - 1``. If no Splice
    widens the context, nothing is changed.

    Otherwise the matched node's input is routed through
    ``Select(ones == counter, input, zeros)``. The counter is either reused from an
    existing subgraph matching the pattern below, or built as:
    ReadValue -> Crop(offset 1, dim context_len - 1) -> Concat(with ones) -> Assign.
    The first element of the Concat output (Crop offset 0, dim 1) is used as the counter.

    :param graph: graph being transformed
    :param match: pattern match dict; ``match['op']`` is the node whose input is gated
    """
    memory_node = match['op']
    # Assign nodes created by this transformation itself are skipped.
    if memory_node.name == 'iteration_number_out':
        return

    # Names of every node that directly consumes a Parameter output.
    parameter_consumers = [
        destination.node.name
        for parameter in graph.get_op_nodes(op='Parameter')
        for destination in parameter.out_port(0).get_destinations()
    ]

    try:
        subgraph = invert_sub_graph_between_nodes(
            graph, [memory_node.in_port(0).get_source().node.name], parameter_consumers)
    except Error:
        return

    # Length of context after which the state of inference becomes meaningful.
    subgraph_nodes = [Node(graph, node_id) for node_id in subgraph]
    context_len = 1 + sum(len(candidate.context) - 1
                          for candidate in subgraph_nodes
                          if candidate.kind == 'op' and candidate.op == 'Splice')
    if context_len == 1:
        return

    source_port = memory_node.in_port(0).get_source()
    source_shape = memory_node.in_port(0).data.get_shape()
    memory_node.in_port(0).disconnect()

    # Select(condition, real data, zeros) placed before the state-saving node.
    gate = Select(graph, {'name': 'select_' + memory_node.name}).create_node()
    zeros = Const(graph, {'name': 'zero_else', 'value': np.zeros(source_shape)}).create_node()
    gate.in_port(1).connect(source_port)
    gate.in_port(2).connect(zeros.out_port(0))

    # Look for an already existing iteration counter of the same context length.
    # NOTE(review): 'mem_in_data' must have shape exactly [context_len]; confirm this matches the shape
    # of the counters created below, otherwise an existing counter is never reused.
    counter_nodes = [
        ('mem_in', dict(op='ReadValue')),
        ('mem_in_data', dict(shape=int64_array([context_len]))),
        ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                             dim=int64_array([context_len - 1]))),
        ('crop_mem_in_data', dict()),
        ('concat', dict(op='Concat', axis=1)),
        ('concat_data', dict()),
        ('const_1', dict(op='Const')),
        ('const_1_data', dict()),
        ('mem_out', dict(op='Assign')),
        ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]),
                          dim=int64_array([1]))),
        ('crop_out_data', dict()),
        ('select', dict(op='Select')),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        initial_state = create_zero_value_with_batch_from_input(source_port, context_len, np.int32)
        read_counter = ReadValue(graph, {'name': 'iteration_number',
                                         'variable_id': 'iteration_' + memory_node.name}).create_node()
        read_counter.in_port(0).connect(initial_state.out_port(0))

        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]), 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(read_counter.out_port(0))

        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        shifted = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        shifted.in_port(0).connect(cut_first.out_port(0))
        shifted.in_port(1).connect(ones.out_port(0))

        write_counter = Assign(graph, {'name': 'iteration_number_out',
                                       'variable_id': 'iteration_' + memory_node.name}).create_node()
        write_counter.in_port(0).connect(shifted.out_port(0))
        sink = Result(graph, {}).create_node()
        write_counter.out_port(0).connect(sink.in_port(0))

        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]), 'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(shifted.out_port(0))
        counter_port = cut_last.out_port(0)
    else:
        matched = inverse_dict(counter_match)
        ones = Node(graph, matched['const_1'])
        counter_port = Node(graph, matched['crop_out']).out_port(0)

    # Counter == 1 means the context has been gathered and the data is correct, so it may be saved.
    # Otherwise the data is garbage and the initial memory state must stay untouched.
    is_ready = Equal(graph, {'name': counter_port.node.name + '/cast_to_bool'}).create_node()
    is_ready.in_port(0).connect(ones.out_port(0))
    is_ready.in_port(1).connect(counter_port)

    gate.in_port(0).connect(is_ready.out_port(0))
    gate.out_port(0).connect(memory_node.in_port(0))
    gate.out_port(0).data.set_shape(source_shape)
def find_and_replace_pattern(self, graph: Graph):
    """Decompose every EmbeddingSegmentsMean node into EmbeddingSegmentsSum divided by segment counts.

    For each node:
      1. build a ``[num_segments, num_indices]`` membership matrix by comparing
         ``range(0, num_segments)`` (unsqueezed on axis 1) with ``segment_ids``
         (unsqueezed on axis 0), mapped to 1/0 via Select;
      2. ReduceSum it over axis 1 to get the number of indices per segment;
      3. replace zero counts with one, because such a segment uses the single
         default embedding vector;
      4. ConvertLike the counts to the embedding table's type;
      5. move all connected inputs to a new EmbeddingSegmentsSum node;
      6. divide the sum by the counts; the Div takes over the original name.

    Nodes with input port 5 connected (a weights vector) are not supported. They are
    left unchanged and the loop moves on to the next node.

    :param graph: graph being transformed
    """
    for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
        embedding_segments_mean_name = embedding_segments_mean.soft_get('name', embedding_segments_mean.id)
        embedding_table_input = embedding_segments_mean.in_port(0)
        segment_ids_input = embedding_segments_mean.in_port(2)
        num_segments_input = embedding_segments_mean.in_port(3)

        # TODO: support EmbeddingSegmentsMean with specified weights vector.
        # This case has not appeared in models so far, so the EmbeddingSegmentsOperation fusion
        # transformations do not handle it either. Use `continue` (not `return`) so that the remaining
        # EmbeddingSegmentsMean nodes in the graph are still decomposed.
        if embedding_segments_mean.is_in_port_connected(5):
            continue

        # 1. compute the indices membership matrix, i.e. which indices belong to which object;
        # the shape of this matrix is [num_segments, num_indices]
        non_norm_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Range, {0: int64_array(0), 2: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegments', 'output_type': np.int64})
        num_segments_input.get_connection().add_destination(non_norm_range_1_to_num_segments.in_port(1))

        # convert the range to the type of num_segments
        range_1_to_num_segments = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'}).create_node()
        range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
        num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))

        unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsUnsqueeze'})
        unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))
        unsqueeze_segment_ids = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)},
            {'name': embedding_segments_mean_name + '/SegmentIdsUnsqueeze'})
        segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))

        boolean_membership_matrix = Equal(
            graph, {'name': embedding_segments_mean_name + '/BooleanMembershipMatrix'}).create_node()
        boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
        boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))

        shape_of_membership_matrix = Shape(
            graph, {'name': embedding_segments_mean_name + '/ShapeOfMembershipMatrix'}
        ).create_node([boolean_membership_matrix])
        one_scalar_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/OneScalar', 'value': int64_array([1])}
        ).create_node()
        one_constant = Broadcast(
            graph, {'name': embedding_segments_mean_name + '/One'}
        ).create_node([one_scalar_constant, shape_of_membership_matrix])
        zero_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/Zero', 'value': int64_array(0)}
        ).create_node()
        membership_matrix = Select(
            graph, {'name': embedding_segments_mean_name + '/MembershipMatrix', 'auto_broadcast': 'numpy'}
        ).create_node([boolean_membership_matrix, one_constant, zero_constant])

        # 2. compute the number of indices belonging to each object from the batch;
        # these are the normalization coefficients
        num_indices_per_object = create_op_with_const_inputs(
            graph, ReduceSum, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/NumIndicesPerObject'})
        num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))

        # 3. replace a zero coefficient (no indices belong to an object) with one,
        # because the single default embedding vector is used for such an object
        where_zero_number = Equal(
            graph, {'name': embedding_segments_mean_name + '/WhereZeroIndicesNumber'}
        ).create_node([num_indices_per_object, zero_constant])
        normalized_num_indices_per_object = Select(
            graph, {'name': embedding_segments_mean_name + '/NormNumIndicesPerObject', 'auto_broadcast': 'numpy'}
        ).create_node([where_zero_number, one_scalar_constant, num_indices_per_object])

        # 4. cast normalized_num_indices_per_object to the same type as the embedding vector table
        norm_coefficients = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/NormCoefficients'}).create_node()
        norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
        embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))

        # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum, moving every connected input
        embedding_segments_sum = EmbeddingSegmentsSum(
            graph, {'name': embedding_segments_mean_name + '/EmbeddingSegmentsSum'}).create_node()
        for in_port in embedding_segments_mean.in_ports():
            if embedding_segments_mean.is_in_port_connected(in_port):
                embedding_segments_mean.in_port(in_port).get_connection().set_destination(
                    embedding_segments_sum.in_port(in_port))

        # 6. normalize the EmbeddingSegmentsSum results by the computed coefficients
        result_node = Div(graph, {'name': embedding_segments_mean_name + '/Div'}).create_node(
            [embedding_segments_sum, norm_coefficients])
        embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))
        rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
                      (result_node, embedding_segments_mean_name)])
        graph.remove_nodes_from([embedding_segments_mean.id])
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    """Replace a FakeQuantize over compressed weights with an explicit dequantization subgraph.

    Input 0 of ``fake_quantize`` must come from a ``Convert`` node whose
    ``dst_type`` is ``quantized_type``. That connection is re-routed through a Cast
    to ``dst_type``. The FakeQuantize output consumers then read
    ``(x - zero_point) * scale`` instead, where::

        scale      = (out_high - out_low) / (in_high - in_low)
        zero_point = 0                          if scale == 0
                     in_low - out_low / scale   otherwise

    The four limits are the sources of FakeQuantize inputs 1..4. Finally,
    ``fake_quantize`` and ``fake_quantize.out_node(0)`` are passed to
    ``graph.remove_nodes_from``.

    :param fake_quantize: FakeQuantize node to replace
    :param dst_type: numpy type the compressed weights are cast to
    :param quantized_type: expected ``dst_type`` of the Convert feeding input 0
    :return: the Mul node that now produces the dequantized values
    :raises AssertionError: if input 0 is not produced by the expected Convert
    """
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        'Weights aren`t compressed as expected for node {}'.format(name)

    # Cast the compressed weights to the target type. stop_value_propagation is presumably set to keep
    # this cast from being constant-folded -- confirm.
    dequantizing_cast = Cast(graph, dict(
        name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
        dst_type=dst_type, stop_value_propagation=True)).create_node()
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # scale = (out_high - out_low) / (in_high - in_low)
    output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
    output_range.in_port(0).connect(out_high)
    output_range.in_port(1).connect(out_low)
    input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
    input_range.in_port(0).connect(in_high)
    input_range.in_port(1).connect(in_low)
    scale = Div(graph, {'name': name + '/scale'}).create_node()
    scale.in_port(0).connect(output_range.out_port(0))
    scale.in_port(1).connect(input_range.out_port(0))

    # shift = in_low - out_low / scale
    descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
    descaled_output_low.in_port(0).connect(out_low)
    descaled_output_low.in_port(1).connect(scale.out_port(0))
    shift = Sub(graph, {'name': name + '/shift'}).create_node()
    shift.in_port(0).connect(in_low)
    shift.in_port(1).connect(descaled_output_low.out_port(0))

    # zero_point = 0 where scale == 0 (shift divides out_low by scale there), otherwise shift
    zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
    scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
    scale_eq_zero.in_port(0).connect(scale.out_port(0))
    scale_eq_zero.in_port(1).connect(zero.out_port(0))
    zero_point = Select(graph, {'name': name + '/zero_point'}).create_node()
    zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
    zero_point.in_port(1).connect(zero.out_port(0))
    zero_point.in_port(2).connect(shift.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
    sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
    sub_zp.in_port(1).connect(zero_point.out_port(0))
    # NOTE: the 'mulpiply' spelling in the node name is kept as-is so generated node names stay stable.
    mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
    mul_scale.in_port(0).connect(sub_zp.out_port(0))
    mul_scale.in_port(1).connect(scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    # NOTE(review): out_node(0) is passed as a Node object rather than an id; confirm it is really removed.
    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
    # Honor the declared `-> Node` return type.
    return mul_scale
def insert_select(graph: Graph, node: Node):
    """Gate the input of ``node`` with a Select until enough context has been gathered.

    The context length is ``node.frame_time + 1``. With a context of 1 nothing is
    inserted. Otherwise the input of ``node`` is routed through
    ``Select(ones == counter, input, zeros)``.

    The counter is a ReadValue/Assign pair: the state is cropped from offset 1
    (dim ``context_len - 1``) and concatenated with ones on axis 1 before being
    written back. The first element of that Concat output (Crop offset 0, dim 1)
    is compared against the ones. An existing counter matching the pattern below
    is reused instead of building a new one.

    :param graph: graph being transformed
    :param node: node whose input is gated; must provide ``frame_time``
    """
    context_len = node.frame_time + 1
    if context_len == 1:
        return

    source_port = node.in_port(0).get_source()
    source_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # Select(condition, real data, zeros) placed before the state-saving node to avoid saving garbage.
    gate = Select(graph, {'name': 'select_' + node.name}).create_node()
    zeros = create_const_with_batch_from_input(source_port, source_shape[1])
    gate.in_port(1).connect(source_port)
    gate.in_port(2).connect(zeros.out_port(0))

    # Look for an already existing iteration counter of the same context length.
    # NOTE(review): 'mem_in_data' must have shape exactly [context_len]; confirm this matches the shape
    # of the counters created below, otherwise an existing counter is never reused.
    counter_nodes = [
        ('mem_in', dict(op='ReadValue')),
        ('mem_in_data', dict(shape=int64_array([context_len]))),
        ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                             dim=int64_array([context_len - 1]))),
        ('crop_mem_in_data', dict()),
        ('concat', dict(op='Concat', axis=1)),
        ('concat_data', dict()),
        ('const_1', dict(op='Const')),
        ('const_1_data', dict()),
        ('mem_out', dict(op='Assign')),
        ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]),
                          dim=int64_array([1]))),
        ('crop_out_data', dict()),
        ('select', dict(op='Select')),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        initial_state = create_const_with_batch_from_input(source_port, context_len, precision=np.int32)
        read_counter = ReadValue(graph, {'name': 'iteration_number',
                                         'variable_id': 'iteration_' + node.name}).create_node()
        read_counter.in_port(0).connect(initial_state.out_port(0))

        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]), 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(read_counter.out_port(0))

        ones = create_const_with_batch_from_input(source_port, 1, 1, np.int32)
        shifted = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        shifted.in_port(0).connect(cut_first.out_port(0))
        shifted.in_port(1).connect(ones.out_port(0))

        write_counter = Assign(graph, {'name': 'iteration_number_out',
                                       'variable_id': 'iteration_' + node.name}).create_node()
        write_counter.in_port(0).connect(shifted.out_port(0))
        sink = Result(graph, {}).create_node()
        write_counter.out_port(0).connect(sink.in_port(0))

        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]), 'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(shifted.out_port(0))
        counter_port = cut_last.out_port(0)
    else:
        matched = inverse_dict(counter_match)
        ones = Node(graph, matched['const_1'])
        counter_port = Node(graph, matched['crop_out']).out_port(0)

    # Counter == 1 means the context has been gathered and the data is correct, so it may be saved.
    # Otherwise the data is garbage and the initial memory state must stay untouched.
    is_ready = Equal(graph, {'name': counter_port.node.name + '/cast_to_bool'}).create_node()
    is_ready.in_port(0).connect(ones.out_port(0))
    is_ready.in_port(1).connect(counter_port)

    gate.in_port(0).connect(is_ready.out_port(0))
    gate.out_port(0).connect(node.in_port(0))
    gate.out_port(0).data.set_shape(source_shape)