def replace_sub_graph(self, graph: Graph, match: dict):
    """Swap the matched begin/size style slice node for a ``Slice`` node fed with begin/end.

    The end input is computed inside the graph as ``begin + size`` (cast to int64);
    wherever ``size[i] == -1`` the int32 maximum is selected as ``end[i]`` instead.

    :param graph: graph that owns the matched node and receives the new nodes
    :param match: pattern match, ``match['op']`` is the node being replaced
    """
    tf_slice = match['op']
    base_name = tf_slice.soft_get('name', tf_slice.id)

    new_slice = Slice(graph).create_node()
    # the new Slice inherits the original name, the old node gets a throw-away one
    rename_nodes([(tf_slice, base_name + '/to_be_removed'), (new_slice, base_name)])

    size_is_minus_one = Equal(graph, dict(name=base_name + '/equal')).create_node()
    minus_one = Const(graph, dict(name=base_name + '/minus_one', value=np.array(-1))).create_node()
    int32_max = Const(graph, dict(name=base_name + '/int32_max', value=np.iinfo(np.int32).max)).create_node()
    end_select = Select(graph, dict(name=base_name + '/select')).create_node()

    # converts sizes to ends: end = begin + size
    begin_plus_size = Add(graph, dict(name=base_name + '/end_const')).create_node()

    # move the data input (port 0) from the old node to the new Slice
    data_in = tf_slice.in_port(0)
    data_in.get_source().connect(new_slice.in_port(0))
    data_in.disconnect()

    # move the begin input (port 1) from the old node to the start port of Slice
    begin_in = tf_slice.in_port(1)
    begin_in.get_source().connect(new_slice.in_port(1))
    begin_in.disconnect()

    # feed begin (taken from the new Slice) and size (port 2 of the old node) into the sum
    new_slice.in_port(1).get_source().connect(begin_plus_size.in_port(0))
    size_in = tf_slice.in_port(2)
    size_in.get_source().connect(begin_plus_size.in_port(1))
    size_in.disconnect()

    # condition: size == -1
    begin_plus_size.in_port(1).get_source().connect(size_is_minus_one.in_port(0))
    minus_one.out_port(0).connect(size_is_minus_one.in_port(1))

    # Select(size == -1, int32_max, begin + size) -> end port (2) of Slice
    size_is_minus_one.out_port(0).connect(end_select.in_port(0))
    int32_max.out_port(0).connect(end_select.in_port(1))
    begin_plus_size.out_port(0).connect(end_select.in_port(2))
    end_select.out_port(0).connect(new_slice.in_port(2))

    # the sum is cast to int64 before it reaches the Select
    to_i64 = Cast(graph, dict(name=begin_plus_size.name + '/CastToI64', dst_type=np.int64)).create_node()
    end_select.in_port(2).get_connection().insert_node(to_i64)

    # consumers of the old node now read from the new Slice
    tf_slice.out_port(0).get_connection().set_source(new_slice.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Collapse a Switch/Switch -> op -> Identity -> Merge <- Switch construct into a ``Select``.

    For every ``Merge`` node the two inputs are inspected in both orders: one must come
    from a ``Switch`` and the other from an ``Identity`` whose producer ``op`` is fed by
    two further ``Switch`` nodes on its ports 0 and 1. When all three Switch nodes share
    the same source on port 1, the Merge is replaced with a Select driven by that source
    and the three Switch nodes plus the Merge are removed.

    :param graph: graph to transform in place
    """

    def producer_of_type(port, op_type):
        # node feeding `port` when the port is connected and the node has op `op_type`
        if port.disconnected():
            return None
        producer = port.get_source().node
        if producer.op != op_type:
            return None
        return producer

    for merge in graph.get_op_nodes(op='Merge'):
        for switch_side in (0, 1):
            identity_side = 1 - switch_side

            switch_2 = producer_of_type(merge.in_port(switch_side), 'Switch')
            if switch_2 is None:
                continue

            if producer_of_type(merge.in_port(identity_side), 'Identity') is None:
                continue
            false_value_port = merge.in_port(identity_side).get_source()
            true_value_port = switch_2.in_port(0).get_source()

            # operation whose result reaches the Merge through the Identity
            op = false_value_port.node.in_port(0).get_source().node

            switch = producer_of_type(op.in_port(0), 'Switch')
            if switch is None:
                continue

            switch_1 = producer_of_type(op.in_port(1), 'Switch')
            if switch_1 is None:
                continue

            # all three Switch nodes must be driven by the very same source on port 1
            condition_port = switch.in_port(1).get_source()
            same_condition = condition_port == switch_1.in_port(1).get_source() and \
                condition_port == switch_2.in_port(1).get_source()
            if not same_condition:
                continue

            select = Select(graph, {'name': merge.soft_get('name') + '/Select/', 'format': 'tf'}).create_node()
            select.in_port(0).connect(condition_port)
            select.in_port(1).connect(true_value_port)
            select.in_port(2).connect(false_value_port)

            merge.out_port(0).get_connection().set_source(select.out_port(0))

            assert 1 in op.in_ports() and 0 in op.in_ports()

            op.in_port(0).disconnect()
            op.in_port(1).disconnect()

            # bypass the two Switch nodes: their data inputs go straight into `op`
            switch.in_port(0).get_connection().set_destination(op.in_port(0))
            switch_1.in_port(0).get_connection().set_destination(op.in_port(1))

            graph.remove_nodes_from(nodes=[switch_1.id, switch.id, switch_2.id, merge.id])
            # the Merge is gone, so the second port order must not be examined
            break
def replace_pattern(graph: Graph, match: dict):
    """Put a ``Select`` in front of input 0 of the matched node so that zeros are stored
    instead of the real input until the iteration counter signals valid data.

    The context length is 1 plus ``len(context) - 1`` of every ``Splice`` op found in the
    sub-graph returned by ``invert_sub_graph_between_nodes``. Nothing is changed when the
    matched node is named ``iteration_number_out``, when that helper raises ``Error`` or
    when the context length stays 1. An existing ReadValue/Crop/Concat/Assign counter with
    a matching length is reused; otherwise a new one is built.

    :param graph: graph to transform in place
    :param match: pattern match, ``match['op']`` is the node whose input gets guarded
    """
    node = match['op']

    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    parameter_consumers = [
        destination.node.name
        for parameter in graph.get_op_nodes(op='Parameter')
        for destination in parameter.out_port(0).get_destinations()
    ]

    try:
        subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name],
                                                  parameter_consumers)
    except Error:
        return

    subgraph_nodes = (Node(graph, node_id) for node_id in subgraph)
    context_len = 1 + sum(len(candidate.context) - 1
                          for candidate in subgraph_nodes
                          if candidate.kind == 'op' and candidate.op == 'Splice')

    if context_len == 1:
        return

    state_input = node.in_port(0)
    in_node_port = state_input.get_source()
    in_node_shape = state_input.data.get_shape()
    state_input.disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, dict(name='select_' + node.name)).create_node()
    zero_else = Const(graph, dict(name='zero_else', value=np.zeros(in_node_shape))).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    counter_nodes = [
        ('mem_in', {'op': 'ReadValue'}),
        ('mem_in_data', {'shape': int64_array([context_len])}),
        ('crop_mem_in', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([1]),
                         'dim': int64_array([context_len - 1])}),
        ('crop_mem_in_data', {}),
        ('concat', {'op': 'Concat', 'axis': 1}),
        ('concat_data', {}),
        ('const_1', {'op': 'Const'}),
        ('const_1_data', {}),
        ('mem_out', {'op': 'Assign'}),
        ('crop_out', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([0]),
                      'dim': int64_array([1])}),
        ('crop_out_data', {}),
        ('select', {'op': 'Select'}),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        # no suitable counter yet: build ReadValue -> Crop -> Concat(ones) -> Assign/Crop
        variable_id = 'iteration_' + node.name

        init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
        counter_read = ReadValue(graph, dict(name='iteration_number', variable_id=variable_id)).create_node()
        counter_read.in_port(0).connect(init_value_mem_out.out_port(0))

        cut_first = Crop(graph, dict(name='cut_first', axis=int64_array([1]), offset=int64_array([1]),
                                     dim=int64_array([context_len - 1]))).create_node()
        cut_first.in_port(0).connect(counter_read.out_port(0))

        ones = Const(graph, dict(name='ones', value=np.ones([1, 1], dtype=np.int32))).create_node()

        concat = Concat(graph, dict(name='concat_ones', in_ports_count=2, axis=1)).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        counter_write = Assign(graph, dict(name='iteration_number_out', variable_id=variable_id)).create_node()
        counter_write.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        counter_write.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, dict(name='cut_last', axis=int64_array([1]), offset=int64_array([0]),
                                    dim=int64_array([1]))).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)
    else:
        # reuse the constant and the last-element crop of the counter that was found
        by_pattern_name = inverse_dict(counter_match)
        ones = Node(graph, by_pattern_name['const_1'])
        input_port = Node(graph, by_pattern_name['crop_out']).out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, dict(name=input_port.node.name + '/cast_to_bool')).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))

    guarded_output = select_node.out_port(0)
    guarded_output.connect(node.in_port(0))
    guarded_output.data.set_shape(in_node_shape)
def replace_pattern(graph: Graph, match: dict):
    """Put a ``Select`` in front of input 0 of the matched node so that zeros are stored
    instead of the real input until the ``Memory``-based iteration counter reports data.

    The context length is 1 plus ``len(context) - 1`` of every ``Splice`` op found in the
    sub-graph returned by ``invert_sub_graph_between_nodes``. Nothing is changed when the
    matched node is named ``iteration_number_out``, when that helper raises ``Error`` or
    when the context length stays 1. An existing Memory/Crop/Concat/Memory counter with a
    matching shape is reused; otherwise a new one is built.

    :param graph: graph to transform in place
    :param match: pattern match, ``match['op']`` is the node whose input gets guarded
    """
    node = match['op']

    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    parameter_consumers = [
        destination.node.name
        for parameter in graph.get_op_nodes(op='Parameter')
        for destination in parameter.out_port(0).get_destinations()
    ]

    try:
        subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name],
                                                  parameter_consumers)
    except Error:
        return

    subgraph_nodes = (Node(graph, node_id) for node_id in subgraph)
    context_len = 1 + sum(len(candidate.context) - 1
                          for candidate in subgraph_nodes
                          if candidate.kind == 'op' and candidate.op == 'Splice')

    if context_len == 1:
        return

    state_input = node.in_port(0)
    in_node_port = state_input.get_source()
    in_node_shape = state_input.data.get_shape()
    state_input.disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, dict(name='select_' + node.name)).create_node()
    zero_else = Const(graph, dict(name='zero_else', value=np.zeros(in_node_shape))).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    counter_nodes = [
        ('mem_in', {'op': 'Memory', 'index': 1, 'shape': int64_array([context_len])}),
        ('mem_in_data', {}),
        ('crop_mem_in', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([1]),
                         'dim': int64_array([context_len - 1])}),
        ('crop_mem_in_data', {}),
        ('concat', {'op': 'Concat', 'axis': 1}),
        ('concat_data', {}),
        ('const_1', {'op': 'Const'}),
        ('const_1_data', {}),
        ('mem_out', {'op': 'Memory', 'index': 0, 'shape': int64_array([context_len])}),
        ('crop_out', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([0]),
                      'dim': int64_array([1])}),
        ('crop_out_data', {}),
        ('select', {'op': 'Select'}),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        # no suitable counter yet: build Memory -> Crop -> Concat(ones) -> Memory/Crop
        memory_id = 'iteration_' + node.name

        counter_read = Memory(graph, dict(name='iteration_number', size=2, index=1, id=memory_id,
                                          shape=int64_array([context_len]), dst_type=np.int32)).create_node()
        cut_first = Crop(graph, dict(name='cut_first', axis=int64_array([1]), offset=int64_array([1]),
                                     dim=int64_array([context_len - 1]))).create_node()
        cut_first.in_port(0).connect(counter_read.out_port(0))

        ones = Const(graph, dict(name='ones', value=np.ones([1, 1], dtype=np.int32))).create_node()

        concat = Concat(graph, dict(name='concat_ones', in_ports_count=2, axis=1)).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        counter_write = Memory(graph, dict(name='iteration_number_out', size=2, index=0, id=memory_id,
                                           shape=int64_array([context_len]))).create_node()
        counter_write.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        counter_write.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, dict(name='cut_last', axis=int64_array([1]), offset=int64_array([0]),
                                    dim=int64_array([1]))).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)
    else:
        # reuse the last-element crop of the counter that was found
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)

    # the counter output is used directly as the Select condition
    select_node.in_port(0).connect(input_port)

    guarded_output = select_node.out_port(0)
    guarded_output.connect(node.in_port(0))
    guarded_output.data.set_shape(in_node_shape)
def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
    """Replace the float min/max inputs of the matched node with nudged quantization limits.

    The following sub-graph is built from the original min (port 1) and max (port 2):

        scale               = (max - min) / (quant_max - quant_min)
        zero_point_from_min = quant_min - min / scale
        nudged_zero_point   = quant_min                   if zero_point_from_min < quant_min
                              quant_max                   if zero_point_from_min > quant_max
                              round(zero_point_from_min)  otherwise
        nudged_min          = (quant_min - nudged_zero_point) * scale
        nudged_max          = (quant_max - nudged_zero_point) * scale

    where ``quant_min = int(node.narrow_range)`` and ``quant_max = 2 ** node.num_bits - 1``.
    ``nudged_min`` / ``nudged_max`` are connected to ports 1/2 and also to ports 3/4, which
    are added when missing.

    :param graph: graph that owns the matched node and receives the new nodes
    :param match: pattern match, ``match['op']`` is the node being rewired
    """
    node = match['op']
    name = node.name

    # Zero Point Nudging : Scale counting
    f_min = node.in_port(1).get_source()
    node.in_port(1).disconnect()
    f_max = node.in_port(2).get_source()
    node.in_port(2).disconnect()

    f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
    f_max.connect(f_diff.in_port(0))
    f_min.connect(f_diff.in_port(1))

    quant_min_value = int(node.narrow_range)
    quant_max_value = 2 ** node.num_bits - 1

    i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()

    scale = Div(graph, {'name': name + '/scale'}).create_node()
    f_diff.out_port(0).connect(scale.in_port(0))
    i_diff.out_port(0).connect(scale.in_port(1))

    # Zero Point Nudging : ZP from min counting
    descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
    f_min.connect(descaled_min.in_port(0))
    scale.out_port(0).connect(descaled_min.in_port(1))

    zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
    quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
    quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
    descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))

    # Zero Point Nudging : Nudged Zero Point counting
    zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
    zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
    quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))

    zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
    zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
    quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
    quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))

    rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
    zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))

    nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
    greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()

    # inner select: above quant_max -> quant_max, otherwise the rounded zero point
    greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
    greater_condition.in_port(1).connect(quant_max.out_port(0))
    greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))

    # outer select: below quant_min -> quant_min, otherwise the result of the inner select.
    # The lower clamp must use quant_min; feeding quant_max here would move the nudged
    # range to [-(quant_max - quant_min) * scale, 0] whenever the float minimum is positive.
    nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
    nudged_zero_point.in_port(1).connect(quant_min.out_port(0))
    nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))

    nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
    quant_min.out_port(0).connect(nudged_i_min.in_port(0))
    nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))

    nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
    quant_max.out_port(0).connect(nudged_i_max.in_port(0))
    nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))

    nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
    nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
    scale.out_port(0).connect(nudged_min.in_port(1))

    nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
    nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
    scale.out_port(0).connect(nudged_max.in_port(1))

    nudged_min.out_port(0).connect(node.in_port(1))
    nudged_max.out_port(0).connect(node.in_port(2))

    # FakeQuantize operation has 5 inputs instead of 3 inputs in TensorFlow
    node.add_input_port(3, skip_if_exist=True)
    node.add_input_port(4, skip_if_exist=True)

    node.in_port(3).connect(nudged_min.out_port(0))
    node.in_port(4).connect(nudged_max.out_port(0))

    FakeQuantize.update_node_stat(node, {'levels': node['levels']})
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    """Replace ``fake_quantize`` with an explicit dequantization sub-graph.

    The producer on port 0 must be a ``Convert`` node whose ``dst_type`` equals
    ``quantized_type``. Its output is cast to ``dst_type`` and then transformed as
    ``Mul(Sub(x, zero_point), scale)`` with

        scale      = (out_high - out_low) / (in_high - in_low)
        shift      = in_low - out_low / scale
        zero_point = 0 if scale == 0 else shift

    built from the limits on ports 1..4 of ``fake_quantize``.

    NOTE(review): the signature is annotated ``-> Node`` but no value is returned.

    :param fake_quantize: node to be replaced
    :param dst_type: data type of the dequantized values
    :param quantized_type: data type the producing Convert is expected to emit
    """
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))

    def binary(op_class, suffix, first, second):
        # create `op_class` named <name><suffix> and feed its ports 0 and 1
        new_node = op_class(graph, {'name': name + suffix}).create_node()
        new_node.in_port(0).connect(first)
        new_node.in_port(1).connect(second)
        return new_node

    cast_name = quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type))
    dequantizing_cast = Cast(graph, {'name': cast_name, 'dst_type': dst_type,
                                     'stop_value_propagation': True}).create_node()
    # the quantized data now flows into the cast instead of the FakeQuantize
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low, in_high, out_low, out_high = [fake_quantize.in_port(idx).get_source() for idx in (1, 2, 3, 4)]

    # scale calculation
    output_range = binary(Sub, '/output_range', out_high, out_low)
    input_range = binary(Sub, '/input_range', in_high, in_low)
    scale = binary(Div, '/scale', output_range.out_port(0), input_range.out_port(0))

    # shift calculation
    descaled_output_low = binary(Div, '/descaled_output_low', out_low, scale.out_port(0))
    shift = binary(Sub, '/shift', in_low, descaled_output_low.out_port(0))

    # zero point falls back to 0 when the scale is 0
    zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
    scale_eq_zero = binary(Equal, '/scale_eq_zero', scale.out_port(0), zero.out_port(0))

    zero_point = Select(graph, {'name': name + '/zero_point'}).create_node()
    for port_idx, source in enumerate((scale_eq_zero.out_port(0), zero.out_port(0), shift.out_port(0))):
        zero_point.in_port(port_idx).connect(source)

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = binary(Sub, '/minus_zp', dequantizing_cast.out_port(0), zero_point.out_port(0))
    mul_scale = binary(Mul, '/mulpiply_by_scale', sub_zp.out_port(0), scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
def replace_sub_graph(self, graph: Graph, match: dict):
    """Rebuild the matched CTCGreedyDecoder / CTCLoss sub-graph around new decoder and loss nodes.

    The sequence length input is replaced with a ``Parameter`` of the same name that is
    expected in mask format (rank 2). The lengths needed by the loss are restored from
    the mask with ``ReduceSum``; the decoder receives the cast and transposed mask. Label
    lengths are derived by counting decoder outputs that differ from -1.

    :param graph: graph to transform in place
    :param match: pattern match with the keys ``seq_len``, ``transpose``,
        ``ctc_greedy_decoder``, ``cast``, ``ctc_loss`` and ``sparse_to_dense``
    :raises Error: when the sequence length node has no ``shape`` or its rank is not 2
    """
    seq_len_tf = match['seq_len']
    transpose_tf = match['transpose']
    decoder_tf = match['ctc_greedy_decoder']
    cast_tf = match['cast']
    ctc_loss_tf = match['ctc_loss']
    sparse_to_dense_tf = match['sparse_to_dense']

    decoder_out_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
    loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
    decoder_tf_name = decoder_tf.soft_get('name', decoder_tf.id)
    log.debug('Found CTCLossFrontReplacer pattern after {} with name {}'.format(decoder_tf.op, decoder_tf.name))

    # create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
    mask_shape = seq_len_tf.soft_get('shape', None)
    if mask_shape is None or len(mask_shape) != 2:
        raise Error('The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                    ' must be specified in a mask format.'.format(decoder_tf_name))
    log.error('The format of input sequence length has been changed to a mask format',
              extra={'is_warning': True})
    mask_type = seq_len_tf.soft_get('data_type', None)
    seq_len_name = seq_len_tf.soft_get('name', seq_len_tf.id)

    seq_mask_placeholder = Parameter(graph, dict(name=seq_len_name, shape=mask_shape,
                                                 data_type=mask_type)).create_node()

    # mask -> sequence length: sum with the constant 1 on port 1 (presumably the reduction axis)
    reduce_to_seq_len = create_op_with_const_inputs(graph, ReduceSum, {1: np.array(1, dtype=np.int32)},
                                                    dict(name=seq_len_name + '/ReduceToSeqLen',
                                                         keep_dims=False))
    reduce_to_seq_len.in_port(0).connect(seq_mask_placeholder.out_port(0))
    seq_len_tf.out_port(0).get_connection().set_source(reduce_to_seq_len.out_port(0))

    # mask for the decoder: cast to the data type from the command line parameters, then transpose
    cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
    casted_seq_mask = Cast(graph, dict(name=seq_len_name + '/CastToFP32',
                                       dst_type=cast_fp_type)).create_node()
    casted_seq_mask.in_port(0).connect(seq_mask_placeholder.out_port(0))
    permuted_casted_seq_mask = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])},
                                                           dict(name=seq_len_name + '/Permute'))
    permuted_casted_seq_mask.in_port(0).connect(casted_seq_mask.out_port(0))
    rename_nodes([(seq_len_tf, seq_len_name + '/AbandonedName'), (seq_mask_placeholder, seq_len_name)])

    # create CTCGreedyDecoder node and set mask node
    # NOTE(review): the fallback of this soft_get is the node id, not a boolean - confirm intent
    ctc_merge_repeated_i = decoder_tf.soft_get('ctc_merge_repeated', decoder_tf.id)
    ctc_greedy_decoder = CTCGreedyDecoderOp(graph, dict(name=decoder_out_name,
                                                        ctc_merge_repeated=ctc_merge_repeated_i)).create_node()
    ctc_greedy_decoder.in_port(1).connect(permuted_casted_seq_mask.out_port(0))
    rename_nodes([(sparse_to_dense_tf, decoder_out_name + '/AbandonedName'),
                  (ctc_greedy_decoder, decoder_out_name)])

    # create CTCLoss node and set attributes
    assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
        'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(loss_name)
    assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
        'The CTCLoss node "{}" misses "ctc_merge_repeated" attribute'.format(loss_name)
    assert ctc_loss_tf.has_valid('unique'), \
        'The CTCLoss node "{}" misses "unique" attribute'.format(loss_name)
    loss_attrs = dict(name=loss_name,
                      preprocess_collapse_repeated=ctc_loss_tf.preprocess_collapse_repeated,
                      ctc_merge_repeated=ctc_loss_tf.ctc_merge_repeated,
                      unique=ctc_loss_tf.unique)
    ctc_loss = CTCLoss(graph, loss_attrs).create_node()
    rename_nodes([(ctc_loss_tf, loss_name + '/AbandonedName'), (ctc_loss, loss_name)])

    # connect logits
    decoder_tf.in_port(0).get_connection().set_destination(ctc_greedy_decoder.in_port(0))
    ctc_loss.in_port(0).disconnect()
    transpose_tf.in_port(0).get_connection().add_destination(ctc_loss.in_port(0))

    # connect logit lengths
    decoder_tf.in_port(1).disconnect()
    ctc_loss.in_port(1).connect(reduce_to_seq_len.out_port(0))

    # connect labels to ctc_loss
    squeeze_op = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])})
    cast_labels_op = Cast(graph, dict(name=decoder_out_name + '/CastLabels', dst_type=np.int32)).create_node()
    squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
    cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
    ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

    # connect label lengths: Select(label == -1, 0, 1) summed with the constant [1] on port 1
    equal_op = create_op_with_const_inputs(graph, Equal, {1: np.array([-1], dtype=np.int32)},
                                           dict(name=decoder_out_name + '/Equal'))
    equal_op.in_port(0).connect(cast_labels_op.out_port(0))
    labels_shape_op = Shape(graph, dict(name=decoder_out_name + '/ShapeOf')).create_node()
    labels_shape_op.in_port(0).connect(equal_op.out_port(0))
    broadcast_one = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1], dtype=np.int32)},
                                                dict(mode='numpy', name=decoder_out_name + '/One'))
    broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
    broadcast_zero = create_op_with_const_inputs(graph, Broadcast, {0: np.array([0], dtype=np.int32)},
                                                 dict(mode='numpy', name=decoder_out_name + '/Zero'))
    broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

    select_node = Select(graph, dict(name=decoder_out_name + '/Select')).create_node()
    select_node.in_port(0).connect(equal_op.out_port(0))
    select_node.in_port(1).connect(broadcast_zero.out_port(0))
    select_node.in_port(2).connect(broadcast_one.out_port(0))
    label_length_node = create_op_with_const_inputs(graph, ReduceSum, {1: int64_array([1])},
                                                    op_attrs=dict(name=decoder_out_name + '/LabelLength',
                                                                  keep_dims=False))
    label_length_node.in_port(0).connect(select_node.out_port(0))
    ctc_loss.in_port(3).connect(label_length_node.out_port(0))

    # set source for output of new sub-graph and remove old nodes
    ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))
    obsolete_ids = [decoder_tf.id, ctc_loss_tf.id, cast_tf.id, sparse_to_dense_tf.id]
    graph.remove_nodes_from(obsolete_ids)
def insert_select(graph: Graph, node: Node):
    """Put a ``Select`` in front of input 0 of ``node`` so that a zero constant is stored
    instead of the real input until the iteration counter signals valid data.

    The context length is ``node.frame_time + 1``; nothing is changed when it equals 1.
    An existing ReadValue/Crop/Concat/Assign counter with a matching length is reused;
    otherwise a new one is built.

    :param graph: graph to transform in place
    :param node: node whose input 0 gets guarded by the Select
    """
    context_len = node.frame_time + 1

    if context_len == 1:
        return

    state_input = node.in_port(0)
    in_node_port = state_input.get_source()
    in_node_shape = state_input.data.get_shape()
    state_input.disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, dict(name='select_' + node.name)).create_node()
    zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    counter_nodes = [
        ('mem_in', {'op': 'ReadValue'}),
        ('mem_in_data', {'shape': int64_array([context_len])}),
        ('crop_mem_in', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([1]),
                         'dim': int64_array([context_len - 1])}),
        ('crop_mem_in_data', {}),
        ('concat', {'op': 'Concat', 'axis': 1}),
        ('concat_data', {}),
        ('const_1', {'op': 'Const'}),
        ('const_1_data', {}),
        ('mem_out', {'op': 'Assign'}),
        ('crop_out', {'op': 'Crop', 'axis': int64_array([1]), 'offset': int64_array([0]),
                      'dim': int64_array([1])}),
        ('crop_out_data', {}),
        ('select', {'op': 'Select'}),
    ]
    counter_edges = [
        ('mem_in', 'mem_in_data'),
        ('mem_in_data', 'crop_mem_in'),
        ('crop_mem_in', 'crop_mem_in_data'),
        ('crop_mem_in_data', 'concat', {'in': 0}),
        ('const_1', 'const_1_data'),
        ('const_1_data', 'concat', {'in': 1}),
        ('concat', 'concat_data'),
        ('concat_data', 'mem_out'),
        ('concat_data', 'crop_out'),
        ('crop_out', 'crop_out_data'),
        ('crop_out_data', 'select'),
    ]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        # no suitable counter yet: build ReadValue -> Crop -> Concat(ones) -> Assign/Crop
        variable_id = 'iteration_' + node.name

        init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
        counter_read = ReadValue(graph, dict(name='iteration_number', variable_id=variable_id)).create_node()
        counter_read.in_port(0).connect(init_value_mem_out.out_port(0))

        cut_first = Crop(graph, dict(name='cut_first', axis=int64_array([1]), offset=int64_array([1]),
                                     dim=int64_array([context_len - 1]))).create_node()
        cut_first.in_port(0).connect(counter_read.out_port(0))

        ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)

        concat = Concat(graph, dict(name='concat_ones', in_ports_count=2, axis=1)).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        counter_write = Assign(graph, dict(name='iteration_number_out', variable_id=variable_id)).create_node()
        counter_write.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        counter_write.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, dict(name='cut_last', axis=int64_array([1]), offset=int64_array([0]),
                                    dim=int64_array([1]))).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)
    else:
        # reuse the constant and the last-element crop of the counter that was found
        by_pattern_name = inverse_dict(counter_match)
        ones = Node(graph, by_pattern_name['const_1'])
        input_port = Node(graph, by_pattern_name['crop_out']).out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, dict(name=input_port.node.name + '/cast_to_bool')).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))

    guarded_output = select_node.out_port(0)
    guarded_output.connect(node.in_port(0))
    guarded_output.data.set_shape(in_node_shape)