def find_and_replace_pattern(self, graph: Graph):
    """Collapse a Switch/Merge conditional sub-graph into a single Select.

    For every 'Merge' node, both orderings of its first two inputs are tried
    while looking for this arrangement:

    * one Merge input is produced by a 'Switch' node;
    * the other Merge input is produced by an 'Identity' node whose input 0
      comes from some operation;
    * inputs 0 and 1 of that operation are each produced by a 'Switch' node;
    * all three Switch nodes read their input 1 from the same source port.

    On a match a Select node is created, the consumers of the Merge output
    are re-attached to it, the operation's two inputs are rewired to the
    connections that fed input 0 of its two Switch producers, and the three
    Switch nodes plus the Merge node are removed from the graph.

    :param graph: graph to transform in place
    """
    for merge_node in graph.get_op_nodes(op='Merge'):
        for switch_side in (0, 1):
            # One side of the Merge has to come straight from a Switch ...
            merge_switch_input = merge_node.in_port(switch_side)
            if merge_switch_input.disconnected():
                continue
            output_switch = merge_switch_input.get_source().node
            if output_switch.op != 'Switch':
                continue

            # ... and the opposite side from an Identity.
            merge_identity_input = merge_node.in_port(1 - switch_side)
            if merge_identity_input.disconnected():
                continue
            else_value = merge_identity_input.get_source()
            if else_value.node.op != 'Identity':
                continue

            then_value = output_switch.in_port(0).get_source()
            guarded_op = else_value.node.in_port(0).get_source().node

            # Both operands of the operation behind the Identity must be
            # Switch outputs as well.
            first_operand = guarded_op.in_port(0)
            if first_operand.disconnected():
                continue
            first_switch = first_operand.get_source().node
            if first_switch.op != 'Switch':
                continue

            second_operand = guarded_op.in_port(1)
            if second_operand.disconnected():
                continue
            second_switch = second_operand.get_source().node
            if second_switch.op != 'Switch':
                continue

            # Only rewrite when every Switch reads input 1 from the same port.
            shared_condition = first_switch.in_port(1).get_source()
            same_condition = (
                shared_condition == second_switch.in_port(1).get_source()
                and shared_condition == output_switch.in_port(1).get_source()
            )
            if not same_condition:
                continue

            select_attrs = {
                'name': merge_node.soft_get('name') + '/Select/',
                'format': 'tf',
            }
            select = Select(graph, select_attrs).create_node()
            select.in_port(0).connect(shared_condition)
            select.in_port(1).connect(then_value)
            select.in_port(2).connect(else_value)

            # Everything that consumed the Merge output now reads the Select.
            merge_node.out_port(0).get_connection().set_source(select.out_port(0))

            assert 1 in guarded_op.in_ports() and 0 in guarded_op.in_ports()

            # Bypass the two Switch nodes in front of the operation: drop its
            # current inputs and re-target the connections that fed input 0
            # of each Switch to the operation itself.
            guarded_op.in_port(0).disconnect()
            guarded_op.in_port(1).disconnect()
            first_switch.in_port(0).get_connection().set_destination(
                guarded_op.in_port(0))
            second_switch.in_port(0).get_connection().set_destination(
                guarded_op.in_port(1))

            graph.remove_nodes_from(
                nodes=[second_switch.id, first_switch.id, output_switch.id, merge_node.id])
            # The Merge node no longer exists, so the other input ordering
            # must not be examined for it.
            break
def insert_select(graph: Graph, node: Node):
    """Put a Select in front of input 0 of ``node``, driven by an iteration counter.

    Nothing is done when ``node.frame_time + 1`` equals 1.  Otherwise the
    current producer of input 0 is detached and routed through a new Select:
    input 1 of the Select receives the original data, input 2 receives a
    constant built by ``create_const_with_batch_from_input`` and input 0
    receives the result of an Equal node that compares a constant with the
    output of a Crop taken from the counter sub-graph.

    The counter sub-graph (ReadValue -> Crop -> Concat with a constant ->
    Assign, plus a Crop on the Concat output) is reused when a matching one
    is already present in the graph and is created otherwise.

    :param graph: graph to modify in place
    :param node: node whose input 0 gets guarded by the Select
    """
    context_len = node.frame_time + 1
    if context_len == 1:
        return

    guarded_input = node.in_port(0)
    data_source = guarded_input.get_source()
    data_shape = guarded_input.data.get_shape()
    guarded_input.disconnect()

    # The Select sits in front of the state-saving node so that garbage is
    # not stored there.
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    fallback_const = create_const_with_batch_from_input(data_source, data_shape[1])
    select_node.in_port(1).connect(data_source)
    select_node.in_port(2).connect(fallback_const.out_port(0))

    # Look for an iteration counter of the right length that already exists.
    counter_nodes = [
        ('mem_in', {'op': 'ReadValue'}),
        ('mem_in_data', {'shape': int64_array([context_len])}),
        ('crop_mem_in', {'op': 'Crop',
                         'axis': int64_array([1]),
                         'offset': int64_array([1]),
                         'dim': int64_array([context_len - 1])}),
        ('crop_mem_in_data', {}),
        ('concat', {'op': 'Concat', 'axis': 1}),
        ('concat_data', {}),
        ('const_1', {'op': 'Const'}),
        ('const_1_data', {}),
        ('mem_out', {'op': 'Assign'}),
        ('crop_out', {'op': 'Crop',
                      'axis': int64_array([1]),
                      'offset': int64_array([0]),
                      'dim': int64_array([1])}),
        ('crop_out_data', {}),
        ('select', {'op': 'Select'}),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    existing_counters = find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges)
    counter_match = next(existing_counters, None)

    if counter_match is None:
        # No counter yet: build ReadValue -> Crop -> Concat(+const) -> Assign
        # and take a Crop of the Concat output as the counter signal.
        counter_init = create_const_with_batch_from_input(
            data_source, context_len, precision=np.int32)
        counter_state = ReadValue(graph, {
            'name': 'iteration_number',
            'variable_id': 'iteration_' + node.name,
        }).create_node()
        counter_state.in_port(0).connect(counter_init.out_port(0))

        drop_first = Crop(graph, {
            'name': 'cut_first',
            'axis': int64_array([1]),
            'offset': int64_array([1]),
            'dim': int64_array([context_len - 1]),
        }).create_node()
        drop_first.in_port(0).connect(counter_state.out_port(0))

        ones = create_const_with_batch_from_input(data_source, 1, 1, np.int32)
        shifted_counter = Concat(graph, {
            'name': 'concat_ones',
            'in_ports_count': 2,
            'axis': 1,
        }).create_node()
        shifted_counter.in_port(0).connect(drop_first.out_port(0))
        shifted_counter.in_port(1).connect(ones.out_port(0))

        counter_store = Assign(graph, {
            'name': 'iteration_number_out',
            'variable_id': 'iteration_' + node.name,
        }).create_node()
        counter_store.in_port(0).connect(shifted_counter.out_port(0))
        store_result = Result(graph, {}).create_node()
        counter_store.out_port(0).connect(store_result.in_port(0))

        take_first = Crop(graph, {
            'name': 'cut_last',
            'axis': int64_array([1]),
            'offset': int64_array([0]),
            'dim': int64_array([1]),
        }).create_node()
        take_first.in_port(0).connect(shifted_counter.out_port(0))
        input_port = take_first.out_port(0)
    else:
        # Reuse the constant and the output Crop of the counter that was found.
        matched_ids = inverse_dict(counter_match)
        ones = Node(graph, matched_ids['const_1'])
        input_port = Node(graph, matched_ids['crop_out']).out_port(0)

    # Compare the counter value with 1.  When they are equal the context is
    # complete and the real data may be saved to memory; otherwise the data
    # is still garbage and the initial memory state must stay untouched.
    is_ready = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    is_ready.in_port(0).connect(ones.out_port(0))
    is_ready.in_port(1).connect(input_port)

    select_node.in_port(0).connect(is_ready.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(data_shape)
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    """Replace ``fake_quantize`` with an explicit dequantization sub-graph.

    The producer of input 0 must be a 'Convert' whose ``dst_type`` equals
    ``quantized_type``.  A Cast to ``dst_type`` is inserted after it and the
    FakeQuantize is replaced by ``Mul(Sub(cast, zero_point), scale)`` where,
    using FakeQuantize inputs 1..4 as in_low, in_high, out_low, out_high:

        scale      = (out_high - out_low) / (in_high - in_low)
        shift      = in_low - out_low / scale
        zero_point = Select(scale == 0, 0, shift)

    NOTE(review): the signature is annotated ``-> Node`` but the visible body
    has no ``return`` statement - confirm what callers expect.

    :param fake_quantize: FakeQuantize node to replace
    :param dst_type: data type of the inserted Cast and of the zero constant
    :param quantized_type: expected ``dst_type`` of the Convert feeding input 0
    """
    graph = fake_quantize.graph
    weights_convert = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert weights_convert.soft_get('type') == 'Convert' and weights_convert.dst_type == quantized_type, \
        'Weights aren`t compressed as expected for node {}'.format(name)

    def build(op_class, suffix, *sources):
        # Create an ``op_class`` node called ``name + suffix`` and attach the
        # given source ports to its inputs 0, 1, ... in that order.
        created = op_class(graph, {'name': name + suffix}).create_node()
        for port_idx, source in enumerate(sources):
            created.in_port(port_idx).connect(source)
        return created

    cast_name = weights_convert.name + "/to_{}".format(np_data_type_to_destination_type(dst_type))
    dequantizing_cast = Cast(graph, {
        'name': cast_name,
        'dst_type': dst_type,
        'stop_value_propagation': True,
    }).create_node()
    # The quantized weights now feed the Cast instead of the FakeQuantize.
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # Limits of dequantize.
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # Scale calculation.
    output_range = build(Sub, '/output_range', out_high, out_low)
    input_range = build(Sub, '/input_range', in_high, in_low)
    scale = build(Div, '/scale', output_range.out_port(0), input_range.out_port(0))

    # Shift calculation.
    descaled_output_low = build(Div, '/descaled_output_low', out_low, scale.out_port(0))
    shift = build(Sub, '/shift', in_low, descaled_output_low.out_port(0))

    # The zero point falls back to 0 wherever the scale equals 0.
    zero = Const(graph, {
        'name': name + '/zero',
        'value': mo_array(0, dtype=dst_type),
    }).create_node()
    scale_eq_zero = build(Equal, '/scale_eq_zero', scale.out_port(0), zero.out_port(0))
    zero_point = build(Select, '/zero_point',
                       scale_eq_zero.out_port(0), zero.out_port(0), shift.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = build(Sub, '/minus_zp', dequantizing_cast.out_port(0), zero_point.out_port(0))
    mul_scale = build(Mul, '/mulpiply_by_scale', sub_zp.out_port(0), scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))
    # NOTE(review): the second element is the object returned by out_node(0),
    # not an id like the first one - confirm this is what remove_nodes_from expects.
    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])