def replace_pattern(graph: Graph, match: dict):
    """
    Replace a time-offset node and its paired node with a ReadValue/Assign state.

    The paired node is looked up by ``node.pair_name``. Whichever of the two has a
    connected input provides the data to remember. The other one's consumers are
    re-fed from the state that was read back.

    :param graph: graph transformed in place; also mutates
        ``graph.graph['cmd_params'].static_shape`` (forced to True)
    :param match: pattern match whose ``'op'`` entry is the node to replace; it must
        carry ``pair_name`` and an integer offset ``t``
    :raises Error: if ``t`` is not negative (only look-back offsets are handled)
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    # Only negative offsets (looking back in time) can be built from a state;
    # the check also rejects t == 0, so the message says ">= 0".
    if node.t >= 0:
        raise Error('Does not support IfDefined with t >= 0')

    # Pick the data provider (the node with a connected input) and the consumer side.
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        # Node attached to the provider's output; it is removed at the end.
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)
    # ReadValue and Assign are tied together by sharing this variable id.
    variable_id = node_name + pair_name

    # The state holds node_t frames of width in_shape[1], concatenated along axis 1.
    init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': variable_id}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # Drop the oldest frame (the first in_shape[1] elements) from the stored state ...
        crop_concat = Crop(graph, {'name': 'Memory_crop',
                                   'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]),
                                   'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        # ... and append the current frame at the end.
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # Consumers receive the oldest stored frame, i.e. the one node_t steps back.
        crop_out = Crop(graph, {'name': 'Memory_crop_out',
                                'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]),
                                'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # One-frame delay: the state is simply the previous input.
        memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    # Crop sizes above are baked from the current in_shape, so the model is
    # switched to static shapes (with a warning) if it was not already.
    if not graph.graph['cmd_params'].static_shape:
        log.error(
            "Model can not be translated in a reshape-able way.\n"
            "Model Optimizer key static_shape was turned on to prevent related errors.\n"
            "There will be no success changing input shapes of the model with the help of "
            "InferenceEngine reshape method", extra={'is_warning': True})
        graph.graph['cmd_params'].static_shape = True

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(graph: Graph, match: dict):
    """
    Replace a time-offset node and its paired node with a ReadValue/Assign state.

    The paired node is looked up by ``node.pair_name``. Whichever of the two has a
    connected input provides the data to remember. The other one's consumers are
    re-fed from the state that was read back.

    :param graph: graph transformed in place
    :param match: pattern match whose ``'op'`` entry is the node to replace; it must
        carry ``pair_name`` and an integer offset ``t``
    :raises Error: if ``t`` is not negative (only look-back offsets are handled)
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    # Only negative offsets (looking back in time) can be built from a state;
    # the check also rejects t == 0, so the message says ">= 0".
    if node.t >= 0:
        raise Error('Does not support IfDefined with t >= 0')

    # Pick the data provider (the node with a connected input) and the consumer side.
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        # Node attached to the provider's output; it is removed at the end.
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)
    # ReadValue and Assign are tied together by sharing this variable id.
    variable_id = node_name + pair_name

    # The state holds node_t frames of width in_shape[1], concatenated along axis 1.
    init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': variable_id}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # Drop the oldest frame (the first in_shape[1] elements) from the stored state ...
        crop_concat = Crop(graph, {'name': 'Memory_crop',
                                   'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]),
                                   'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        # ... and append the current frame at the end.
        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # Consumers receive the oldest stored frame, i.e. the one node_t steps back.
        crop_out = Crop(graph, {'name': 'Memory_crop_out',
                                'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]),
                                'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # One-frame delay: the state is simply the previous input.
        memory_in = Assign(graph, {'name': node_name, 'variable_id': variable_id}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(graph: Graph, match: dict):
    """
    Put a Select in front of the state-saving node ``match['op']``.

    Until enough frames have passed to fill every Splice context between the
    model inputs and this node, zeros are stored instead of the real data. An
    iteration counter (ReadValue/Crop/Concat/Assign) decides when that point is
    reached. An already existing counter that matches the required shape is
    reused if one is found.

    :param graph: graph transformed in place
    :param match: pattern match whose ``'op'`` entry is the node to guard
    """
    node = match['op']

    # The counter's own Assign is named like this; never guard it.
    if node.name == 'iteration_number_out':
        return

    # Every consumer of every Parameter bounds the sub-graph search below.
    param_consumer_names = [dst.node.name
                            for param in graph.get_op_nodes(op='Parameter')
                            for dst in param.out_port(0).get_destinations()]

    try:
        subgraph = invert_sub_graph_between_nodes(
            graph, [node.in_port(0).get_source().node.name], param_consumer_names)
    except Error:
        # Sub-graph extraction failed: leave the node untouched (best effort).
        return

    # Number of frames before the state becomes meaningful: 1 plus the extra
    # frames contributed by each Splice found in the sub-graph.
    visited = (Node(graph, name) for name in subgraph)
    context_len = 1 + sum(len(op_node.context) - 1
                          for op_node in visited
                          if op_node.kind == 'op' and op_node.op == 'Splice')
    if context_len == 1:
        return

    data_source = node.in_port(0).get_source()
    data_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # Select(cond, data, zeros) feeds the state instead of the raw data.
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(data_shape)}).create_node()
    select_node.in_port(1).connect(data_source)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # Shape of a counter sub-graph that can be reused as is.
    counter_nodes = [('mem_in', dict(op='ReadValue')),
                     ('mem_in_data', dict(shape=int64_array([context_len]))),
                     ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                                          dim=int64_array([context_len - 1]))),
                     ('crop_mem_in_data', dict()),
                     ('concat', dict(op='Concat', axis=1)),
                     ('concat_data', dict()),
                     ('const_1', dict(op='Const')),
                     ('const_1_data', dict()),
                     ('mem_out', dict(op='Assign')),
                     ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]),
                                       dim=int64_array([1]))),
                     ('crop_out_data', dict()),
                     ('select', dict(op='Select'))]
    counter_edges = [('mem_in', 'mem_in_data'),
                     ('mem_in_data', 'crop_mem_in'),
                     ('crop_mem_in', 'crop_mem_in_data'),
                     ('crop_mem_in_data', 'concat', {'in': 0}),
                     ('const_1', 'const_1_data'),
                     ('const_1_data', 'concat', {'in': 1}),
                     ('concat', 'concat_data'),
                     ('concat_data', 'mem_out'),
                     ('concat_data', 'crop_out'),
                     ('crop_out', 'crop_out_data'),
                     ('crop_out_data', 'select')]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        # Build a new counter: state = concat(state[1:], 1), starting from the
        # initializer; its first element is what gets compared below.
        counter_variable = 'iteration_' + node.name
        counter_init = create_zero_value_with_batch_from_input(data_source, context_len, np.int32)
        counter_read = ReadValue(graph, {'name': 'iteration_number',
                                         'variable_id': counter_variable}).create_node()
        counter_read.in_port(0).connect(counter_init.out_port(0))

        drop_first = Crop(graph, {'name': 'cut_first',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([1]),
                                  'dim': int64_array([context_len - 1])}).create_node()
        drop_first.in_port(0).connect(counter_read.out_port(0))

        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        shifted = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        shifted.in_port(0).connect(drop_first.out_port(0))
        shifted.in_port(1).connect(ones.out_port(0))

        counter_write = Assign(graph, {'name': 'iteration_number_out',
                                       'variable_id': counter_variable}).create_node()
        counter_write.in_port(0).connect(shifted.out_port(0))
        sink = Result(graph, {}).create_node()
        counter_write.out_port(0).connect(sink.in_port(0))

        take_first = Crop(graph, {'name': 'cut_last',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([1])}).create_node()
        take_first.in_port(0).connect(shifted.out_port(0))
        counter_port = take_first.out_port(0)
    else:
        matched = inverse_dict(counter_match)
        ones = Node(graph, matched['const_1'])
        counter_port = Node(graph, matched['crop_out']).out_port(0)

    # True once the counter element equals 1, i.e. the context has been gathered;
    # before that the Select passes zeros so the initial state is not overwritten
    # with garbage.
    is_ready = Equal(graph, {'name': counter_port.node.name + '/cast_to_bool'}).create_node()
    is_ready.in_port(0).connect(ones.out_port(0))
    is_ready.in_port(1).connect(counter_port)

    select_node.in_port(0).connect(is_ready.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(data_shape)
def replace_op(self, graph: Graph, node: Node):
    """
    Expand ``node`` into an explicit recurrent-cell subgraph built from
    FullyConnected, Split, ScaleShift, Add, Mul, Sigmoid, Tanh and Clamp ops,
    with two ReadValue/Assign state pairs.

    The computation is the following. The local names describe an LSTM cell
    with per-element ScaleShifts of the cell state and an output projection.
        joined = FC(input; gifo_x_weights, gifo_biases) + FC(prev_output; gifo_r_weights)
        s0, s1, s2, s3 = Split(joined, 4, axis=1)
        remember = Sigmoid(s1 + ScaleShift(prev_state; input_gate_weights))
        forget   = Sigmoid(s2 + ScaleShift(prev_state; forget_gate_weights))
        state    = Clamp(forget * prev_state + Tanh(s0) * remember, -clip_value, clip_value)
        out_gate = Sigmoid(s3 + ScaleShift(state; output_gate_weights))
        output   = FC(Tanh(state) * out_gate; projection_weights)
    ``state`` and ``output`` are written back through Assign nodes and read as
    ``prev_state``/``prev_output`` on the next inference.

    The order of node creation and connection below is kept as is on purpose.

    :param graph: graph transformed in place
    :param node: node to expand; must carry the ``gifo_*``, ``*_gate_weights``,
        ``projection_weights*`` and ``clip_value`` attributes read below
    :return: list with the id of the output-projection FullyConnected node.
        Presumably the replacement framework reconnects ``node``'s consumers to it.
    """
    input_out_port = node.in_port(0).get_source()

    # Variable ids pairing each ReadValue with its Assign:
    # memory_pair_input  -> previous/next output (prev_memory_output / next_lstm_output)
    # memory_pair_output -> previous/next cell state (prev_memory_state / next_lstm_state)
    memory_pair_input = unique_id('id')
    memory_pair_output = unique_id('id')

    # Input -> FullyConnected (input weights + biases)
    fc_layer_after_input_attrs = {
        'name': 'input_fullyconnected',
        'out-size': node.gifo_x_weights_shape[0],
        'transpose_weights': True,
        'bias_term': True,
    }

    fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
    fc_layer_after_input.in_port(0).connect(input_out_port)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
    input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)

    # Previous output state; its width is the recurrent weights' second dimension.
    init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
                                                                          node.gifo_r_weights_shape[1])
    prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
                                         'variable_id': memory_pair_input}).create_node()
    prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))

    # Previous output -> FullyConnected (recurrent weights, no bias)
    fc_layer_from_prev_state_attrs = {
        'name': 'prev_memory_output_fullyconnected',
        'out-size': node.gifo_r_weights_shape[0],
        'transpose_weights': True,
        'bias_term': False,
    }

    fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
    fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
    input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)

    # joined = FC(prev output) + FC(input)
    join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
    join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
    join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))

    # Split joined into 4 equal parts along axis 1. The slice order is mandatory:
    #   port 0 -> Tanh (candidate)
    #   port 1 -> Add with input-gate ScaleShift
    #   port 2 -> Add with forget-gate ScaleShift
    #   port 3 -> Add with output-gate ScaleShift
    split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
    split_joined_input = Split(graph, {'name': 'join_input_split',
                                       'num_splits': 4,
                                       'out_ports_count': 4}).create_node()
    split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
    split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))

    # Previous cell state; its width is the length of input_gate_weights along axis 0.
    init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
                                                                         node.input_gate_weights.shape[0])
    prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
                                        'variable_id': memory_pair_output}).create_node()
    prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))

    # prev state -> ScaleShift(input_gate_weights), weights only (no bias)
    state_input_scaleshift_attrs = {'name': 'input_scaleshift', 'bias_term': False}
    state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
    state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)

    # prev state -> ScaleShift(forget_gate_weights), weights only (no bias)
    state_forget_scaleshift_attrs = {'name': 'forget_scaleshift', 'bias_term': False}
    state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
    state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
    input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)

    # split[1] + ScaleShift(input)
    join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'}).create_node()
    join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
    join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))

    # split[2] + ScaleShift(forget)
    join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
                                                         }).create_node()
    join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
    join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))

    # split[0] -> Tanh (candidate values)
    remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
    remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

    # (split[1] + ScaleShift(input)) -> Sigmoid
    remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
    remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))

    # (split[2] + ScaleShift(forget)) -> Sigmoid
    forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
    forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))

    # forget_sigmoid * prev state
    join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
    join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
    join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))

    # Tanh(candidate) * remember_sigmoid
    join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
    join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
    join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))

    # new state (before clipping) = sum of the two products
    join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
    join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
    join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))

    # Clamp the new state to [-clip_value, clip_value]
    join_forget_clamp = create_op_with_const_inputs(graph, Clamp, {1: np.array(-node.clip_value, dtype=np.float32),
                                                                   2: np.array(node.clip_value, dtype=np.float32)},
                                                    {'name': 'join_forget_clamp'}, join_forget_remember_sum)

    # Clamped state -> Assign (saved as the next cell state) -> Result
    next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
                                     'variable_id': memory_pair_output}).create_node()
    next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))

    res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
    res_node.in_port(0).connect(next_lstm_state.out_port(0))

    # Clamped state -> Tanh
    state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
    state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

    # Clamped state -> ScaleShift(output_gate_weights), weights only (no bias)
    clamp_scaleshift_attrs = {'name': 'clamp_scaleshift', 'bias_term': False}
    clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
    clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
    input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)

    # split[3] + ScaleShift(clamped state)
    join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
                                                        }).create_node()
    join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
    join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))

    # (split[3] + ScaleShift) -> Sigmoid (output gate)
    output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
    output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))

    # Tanh(clamped state) * output gate
    joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
    joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
    joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

    # Projection: FullyConnected(projection_weights), no bias
    fc_output_attrs = {'name': 'FullyConnected',
                       'out-size': node.projection_weights_shape[0],
                       'transpose_weights': True,
                       'bias_term': False}
    fc_output = FullyConnected(graph, fc_output_attrs).create_node()
    fc_output.in_port(0).connect(joined_output_mul.out_port(0))
    input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)

    # The projection output is saved as the next output state (Assign -> Result).
    # Its edge to the node's former consumers is created after replacement.
    next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
                                      'variable_id': memory_pair_input}).create_node()
    next_lstm_output.in_port(0).connect(fc_output.out_port(0))

    res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
    res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

    return [fc_output.id]
def replace_pattern(graph: Graph, match: dict):
    """
    Replace a Splice node with a sliding window kept in ReadValue/Assign state.

    Each step the window drops its oldest frame (Crop) and appends the current
    input (Concat); the updated window is both saved and used as the output.
    When ``const_dim`` is non-zero, the last ``const_dim`` input columns get a
    separate window. Only the oldest frame of that window is appended to the
    output.

    :param graph: graph transformed in place; the Splice node is removed
    :param match: pattern match whose ``'op'`` entry is the Splice node
        (reads ``const_dim`` and ``context``)
    """
    splice = match['op']
    in_shape = splice.in_port(0).data.get_shape().copy()
    context_size = len(splice.context)
    frame_width = in_shape[1] - splice.const_dim
    window_width = frame_width * context_size

    def make_zero_init(name, width):
        # Zero Const of shape [in_shape[0], width] used to initialize a ReadValue.
        return Const(graph, {'name': name,
                             'value': np.zeros(int64_array([in_shape[0], width])),
                             'shape': int64_array([in_shape[0], width])}).create_node()

    # Main window: ReadValue -> Crop(drop oldest frame) -> Concat(+ new frame) -> Assign
    window_variable = unique_id('id')
    window_read = ReadValue(graph, {'name': 'prev_splice_memory', 'variable_id': window_variable}).create_node()

    drop_oldest = Crop(graph, {'name': 'Splice_Crop',
                               'axis': int64_array([1]),
                               'offset': int64_array([frame_width]),
                               'dim': int64_array([window_width - frame_width])}).create_node()
    drop_oldest.in_port(0).connect(window_read.out_port(0))

    window = Concat(graph, {'name': 'Splice_Concat', 'in_ports_count': 2, 'axis': 1}).create_node()
    window.in_port(0).connect(drop_oldest.out_port(0))

    window_write = Assign(graph, {'name': 'out_splice_memory', 'variable_id': window_variable}).create_node()
    window_write.in_port(0).connect(window.out_port(0))
    Result(graph).create_node().in_port(0).connect(window_write.out_port(0))

    if splice.const_dim == 0:
        # The whole input feeds the window directly.
        make_zero_init('init_value_' + splice.name, window_width).out_port(0).connect(window_read.in_port(0))
        splice.in_port(0).get_connection().set_destination(window.in_port(1))
        splice.out_port(0).get_connection().set_source(window.out_port(0))
    else:
        const_width = splice.const_dim
        const_window_width = const_width * context_size

        # Split the input along axis 1 into the regular part and the const_dim tail.
        splitter = create_op_with_const_inputs(graph, VariadicSplit,
                                               {1: int64_array(1), 2: int64_array([frame_width, const_width])},
                                               {'name': splice.id + '_split_const', 'out_ports_count': 2})
        splitter.out_port(0).connect(window.in_port(1))

        # Separate window for the const_dim columns.
        const_variable = unique_id('memory_for_const_dim')
        const_init = make_zero_init('init_value_const_dim_in_memory', const_window_width)
        const_read = ReadValue(graph, {'name': 'const_dim_in_memory', 'variable_id': const_variable}).create_node()
        const_init.out_port(0).connect(const_read.in_port(0))

        const_drop_oldest = Crop(graph, {'name': 'const_dim_crop',
                                         'axis': int64_array([1]),
                                         'offset': int64_array([const_width]),
                                         'dim': int64_array([const_window_width - const_width])}).create_node()
        const_drop_oldest.in_port(0).connect(const_read.out_port(0))

        const_window = Concat(graph, {'name': 'const_dim_concat', 'in_ports_count': 2, 'axis': 1}).create_node()
        const_window.in_port(0).connect(const_drop_oldest.out_port(0))

        const_write = Assign(graph, {'name': 'const_dim_out_memory', 'variable_id': const_variable}).create_node()
        const_write.in_port(0).connect(const_window.out_port(0))
        Result(graph).create_node().in_port(0).connect(const_write.out_port(0))

        # The const_dim tail enters its window; only the window's first frame is kept.
        splitter.out_port(1).connect(const_window.in_port(1))
        const_first_frame = Crop(graph, {'name': 'const_dim_crop_first',
                                         'axis': int64_array([1]),
                                         'offset': int64_array([0]),
                                         'dim': int64_array([const_width])}).create_node()
        const_first_frame.in_port(0).connect(const_window.out_port(0))

        # Output = [main window, first const_dim frame] along axis 1.
        merged = Concat(graph, {'name': splice.id + '_concat_const', 'axis': 1, 'in_ports_count': 2}).create_node()
        merged.in_port(1).connect(const_first_frame.out_port(0))
        merged.in_port(0).connect(window.out_port(0))

        make_zero_init('init_value_' + splice.name, window_width).out_port(0).connect(window_read.in_port(0))
        splice.in_port(0).get_connection().set_destination(splitter.in_port(0))
        splice.out_port(0).get_connection().set_source(merged.out_port(0))

    # Remove the Splice so later passes neither re-infer nor touch it.
    graph.remove_node(splice.id)
def insert_select(graph: Graph, node: Node):
    """
    Guard the state saved by ``node`` with a Select driven by an iteration counter.

    Zeros are stored instead of real data until ``node.frame_time + 1`` frames
    have been processed. Nothing is inserted when ``frame_time`` is 0. An
    existing counter matching the required shape is reused if one is found.

    :param graph: graph transformed in place
    :param node: state-saving node; its input is rerouted through the Select
    """
    context_len = node.frame_time + 1
    if context_len == 1:
        return

    data_source = node.in_port(0).get_source()
    data_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # Select(cond, data, zeros) feeds the state instead of the raw data.
    guard = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_fill = create_const_with_batch_from_input(data_source, data_shape[1])
    guard.in_port(1).connect(data_source)
    guard.in_port(2).connect(zero_fill.out_port(0))

    # Shape of a counter sub-graph that can be reused as is.
    counter_nodes = [('mem_in', dict(op='ReadValue')),
                     ('mem_in_data', dict(shape=int64_array([context_len]))),
                     ('crop_mem_in', dict(op='Crop', axis=int64_array([1]), offset=int64_array([1]),
                                          dim=int64_array([context_len - 1]))),
                     ('crop_mem_in_data', dict()),
                     ('concat', dict(op='Concat', axis=1)),
                     ('concat_data', dict()),
                     ('const_1', dict(op='Const')),
                     ('const_1_data', dict()),
                     ('mem_out', dict(op='Assign')),
                     ('crop_out', dict(op='Crop', axis=int64_array([1]), offset=int64_array([0]),
                                       dim=int64_array([1]))),
                     ('crop_out_data', dict()),
                     ('select', dict(op='Select'))]
    counter_edges = [('mem_in', 'mem_in_data'),
                     ('mem_in_data', 'crop_mem_in'),
                     ('crop_mem_in', 'crop_mem_in_data'),
                     ('crop_mem_in_data', 'concat', {'in': 0}),
                     ('const_1', 'const_1_data'),
                     ('const_1_data', 'concat', {'in': 1}),
                     ('concat', 'concat_data'),
                     ('concat_data', 'mem_out'),
                     ('concat_data', 'crop_out'),
                     ('crop_out', 'crop_out_data'),
                     ('crop_out_data', 'select')]
    counter_match = next(find_pattern_matches(graph, nodes=counter_nodes, edges=counter_edges), None)

    if counter_match is None:
        # New counter: state = concat(state[1:], 1); its first element is compared below.
        counter_variable = 'iteration_' + node.name
        counter_init = create_const_with_batch_from_input(data_source, context_len, precision=np.int32)
        counter_read = ReadValue(graph, {'name': 'iteration_number',
                                         'variable_id': counter_variable}).create_node()
        counter_read.in_port(0).connect(counter_init.out_port(0))

        drop_first = Crop(graph, {'name': 'cut_first',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([1]),
                                  'dim': int64_array([context_len - 1])}).create_node()
        drop_first.in_port(0).connect(counter_read.out_port(0))

        ones = create_const_with_batch_from_input(data_source, 1, 1, np.int32)
        shifted = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        shifted.in_port(0).connect(drop_first.out_port(0))
        shifted.in_port(1).connect(ones.out_port(0))

        counter_write = Assign(graph, {'name': 'iteration_number_out',
                                       'variable_id': counter_variable}).create_node()
        counter_write.in_port(0).connect(shifted.out_port(0))
        sink = Result(graph, {}).create_node()
        counter_write.out_port(0).connect(sink.in_port(0))

        take_first = Crop(graph, {'name': 'cut_last',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([1])}).create_node()
        take_first.in_port(0).connect(shifted.out_port(0))
        counter_port = take_first.out_port(0)
    else:
        matched = inverse_dict(counter_match)
        ones = Node(graph, matched['const_1'])
        counter_port = Node(graph, matched['crop_out']).out_port(0)

    # True once the counter element equals 1 (context gathered); before that the
    # Select passes zeros so garbage does not overwrite the initial state.
    is_ready = Equal(graph, {'name': counter_port.node.name + '/cast_to_bool'}).create_node()
    is_ready.in_port(0).connect(ones.out_port(0))
    is_ready.in_port(1).connect(counter_port)

    guard.in_port(0).connect(is_ready.out_port(0))
    guard.out_port(0).connect(node.in_port(0))
    guard.out_port(0).data.set_shape(data_shape)