def replace_pattern(graph: Graph, match: dict):
    """Replace a matched offset node pair with a Splice + Crop construction.

    The matched node ('op') and its pair (found via ``node.pair_name``) are
    replaced by a single Splice that gathers a temporal context of
    ``|node.t| + 1`` frames, followed by a Crop that selects the single
    shifted frame. Siblings of the input that are not MemoryOffset/Splice
    are re-wired through an extra Crop selecting the un-shifted frame.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is one half of the offset pair
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    # Pairs with a default value are handled elsewhere — skip them here.
    if pair_node.has_default:
        return

    # Exactly one of the two paired nodes has a connected input; pick data
    # source / final consumer ports accordingly.
    if node.in_port(0).get_source() is not None:
        input_node_out_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_node_in_ports = pair_node.out_port(0).get_destinations()
    else:
        input_node_out_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_node_in_ports = node.out_port(0).get_destinations()

    in_shape = input_node_out_port.data.get_shape().copy()

    node_id = node.id
    node_name = node.name
    node_t = node.t

    # Context covers [t, 0] for a past offset (t < 0) or [0, t] for a
    # future offset (t > 0).
    splice = Splice(graph, {'name': node_name,
                            'id': node_id,
                            'context': int64_array(range(node_t, 1))
                            if node_t < 0 else int64_array(range(0, node_t + 1))}).create_node()
    splice.in_port(0).connect(input_node_out_port)

    # offset of Crop will be 0 (first element) if node_t < 0 and
    # in_shape[1]*node_t (last element) if node_t > 0
    crop = Crop(graph, {'name': 'Splice_Crop',
                        'axis': int64_array([1]),
                        'offset': int64_array([max(0, in_shape[1] * node_t)]),
                        'dim': int64_array([in_shape[1]])}).create_node()

    splice.out_port(0).connect(crop.in_port(0))
    # Splice output holds the whole context: (|t| + 1) frames concatenated
    # along axis 1.
    splice.out_port(0).data.set_shape(int64_array([in_shape[0],
                                                   (abs(node_t) + 1) * in_shape[1]]))

    # Other consumers of the original input must keep seeing the un-shifted
    # frame: route them through a Crop that cuts it back out of the Splice.
    outs = input_node_out_port.get_destinations()
    for in_port in outs:
        out_ = in_port.node
        if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
            crop_input = Crop(graph, {'name': 'Splice_Crop',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([-min(0, in_shape[1] * node_t)]),
                                      'dim': int64_array([in_shape[1]])}).create_node()
            splice.out_port(0).connect(crop_input.in_port(0))

            in_port.disconnect()
            crop_input.out_port(0).connect(in_port)
            crop_input.out_port(0).data.set_shape(in_shape)

    # Feed every consumer of the offset output from the shifted-frame Crop.
    for dest_port in out_node_in_ports:
        dest_port.connect(crop.out_port(0))

    # Drop the consumed output node and both halves of the pair.
    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(graph: Graph, match: dict):
    """Replace a negative-offset node pair with a ReadValue/Assign state loop.

    A recurrent state of ``|t|`` frames is kept in a variable: ReadValue
    produces the saved history (zero-initialized via
    ``create_zero_value_with_batch_from_input``), the current frame is
    appended and written back with Assign, and consumers receive the oldest
    stored frame.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is one half of the offset pair
    :raises Error: for non-negative offsets, which this pass does not support
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    # Exactly one of the paired nodes has a connected input; derive the data
    # source, the consumer port and variable naming from whichever it is.
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    # State buffer holds node_t frames of in_shape[1] features each,
    # initialized with zeros shaped per the runtime batch.
    init_value_memory_out = create_zero_value_with_batch_from_input(input_port,
                                                                    in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name,
                                   'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # Shift the buffer: drop the oldest frame, append the new one.
        crop_concat = Crop(graph, {'name': 'Memory_crop',
                                   'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]),
                                   'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        # Write the shifted buffer back to the variable; Result keeps the
        # Assign alive in the graph.
        memory_in = Assign(graph, {'name': node_name,
                                   'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # Consumers read the oldest stored frame (delay of node_t steps).
        crop_out = Crop(graph, {'name': 'Memory_crop_out',
                                'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]),
                                'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # Single-frame state: the current input is stored as-is and the
        # consumer reads the whole buffer.
        memory_in = Assign(graph, {'name': node_name,
                                   'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    # Drop the consumed output node and both halves of the pair.
    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def find_and_replace_pattern(self, graph: Graph):
    """Adapt every NonMaxSuppression node's inputs and outputs.

    Boxes get a leading batch axis (Unsqueeze), scores are reshaped to
    rank 3, and each connected output is post-processed with a Crop that
    keeps only the last column of the selected-indices triples followed by
    a Squeeze of axis 1.
    """
    for nms in graph.get_op_nodes(op='NonMaxSuppression'):
        # prepare inputs to the NonMaximumSuppression Node
        boxes_unsqueeze = create_op_node_with_second_input(
            graph, Unsqueeze, int64_array([0]),
            {'name': nms.soft_get('name') + '/Unsqueeze_0'})
        nms.in_port(0).get_connection().insert_node(boxes_unsqueeze)

        scores_reshape = create_op_node_with_second_input(
            graph, Reshape, int64_array([1, 1, -1]),
            {'name': nms.soft_get('name') + '/Unsqueeze_1'})
        nms.in_port(1).get_connection().insert_node(scores_reshape)

        nms_name = nms.soft_get('name', nms.id)

        def attach_crop_squeeze(port_idx, crop_name):
            # Keep only the last of the three index columns, then drop the
            # now-degenerate axis 1.
            crop = Crop(graph, {'name': crop_name,
                                'axis': int64_array([1]),
                                'offset': int64_array([2]),
                                'dim': int64_array([1])}).create_node()
            nms.out_port(port_idx).get_connection().insert_node(crop)

            squeeze = create_op_node_with_second_input(
                graph, Squeeze, int64_array([1]), {'name': crop_name + '/Squeeze'})
            crop.out_port(0).get_connection().insert_node(squeeze)

        # prepare output #0
        attach_crop_squeeze(0, nms_name + '/Crop_boxes_')

        # Only touch output #1 when it is actually connected.
        connected_outputs = sum(1 for port in nms.out_ports().values()
                                if not port.disconnected())
        if connected_outputs == 1:
            continue

        # prepare output #1
        attach_crop_squeeze(1, nms_name + '/Crop_scores_')
def replace_pattern(graph: Graph, match: dict):
    """Absorb a smaller sibling Splice whose context is a subset of ours.

    If another Splice consuming the same parent has a context contained in
    this node's context, its consumers are re-pointed at this (larger)
    Splice. Existing Crop consumers just get their offset shifted; other
    consumers get a new Crop inserted that cuts the smaller context out of
    the larger output. The redundant Splice is then removed.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is the larger Splice
    """
    mem = match['op']
    mem_shape = mem.in_port(0).data.get_shape()
    mem_parent = mem.in_port(0).get_source()
    context = mem['context']

    for child_port in mem_parent.get_destinations():
        child = child_port.node
        # check if we find Splice containing context 'context'
        if child['op'] == 'Splice' and child.id != mem.id and \
                set(child['context']).issubset(set(context)):
            # Offset (in frames) of the smaller context's start inside the
            # larger context; scaled by feature size below.
            left_cont_out = child['context'][0]
            left_cont = context[0]

            for child_of_child in child.out_port(0).get_destinations():
                out_transfer = child_of_child.node
                out_transfer_port = child_of_child
                if out_transfer['op'] == 'Crop':
                    # modify existing Crop to get right data from larger Splice
                    out_transfer['offset'] = out_transfer['offset'] + \
                        (left_cont_out - left_cont) * mem_shape[-1]
                else:
                    # insert Crop if we have not one
                    child_of_child.disconnect()
                    crop_node = Crop(graph,
                                     {'name': graph.unique_id(prefix='Splice_crop_'),
                                      'offset': (left_cont_out - left_cont) * mem_shape[-1],
                                      'dim': np.array([len(child['context']) * mem_shape[-1]]),
                                      'axis': np.array([-1])}).create_node()
                    child.out_port(0).connect(crop_node.in_port(0))
                    crop_node.out_port(0).connect(child_of_child)
                    crop_node.out_port(0).data.set_shape(
                        child.out_port(0).data.get_shape())

                    # The new Crop's input is what must be re-pointed below.
                    out_transfer_port = crop_node.in_port(0)

                # move edge to child from old Splice to larger
                out_transfer_port.disconnect()
                mem.out_port(0).connect(out_transfer_port)

            graph.remove_node(child.id)
def replace_pattern(graph: Graph, match: dict):
    """Insert a batch-selecting Crop between the matched node's neighbours.

    The matched node is bypassed: its producer is connected through a Crop
    (first element along axis 0) straight to its consumer. The node itself
    is left disconnected; its 'variable_id' names the new Crop.
    """
    node = match['op']
    node_id = node['variable_id']

    # Remember both neighbours before detaching the node.
    consumer_port = node.out_port(0).get_destination()
    producer_port = node.in_port(0).get_source()
    node.in_port(0).disconnect()
    node.out_port(0).disconnect()

    crop_attrs = {'name': 'Result_for_' + node_id,
                  'dim': np.array([1]),
                  'offset': np.array([0]),
                  'axis': np.array([0])}
    crop = Crop(graph, crop_attrs).create_node()

    # Re-route: producer -> Crop -> consumer.
    producer_port.connect(crop.in_port(0))
    crop.out_port(0).connect(consumer_port)
def replace_pattern(graph: Graph, match: dict):
    """Attach graph inputs or a batch Crop to the matched node.

    If the node has no producer, each of its consumers receives a fresh
    Parameter (so the subgraph becomes a model input) and a warning is
    logged with the input/output name mapping. Otherwise the node is
    bypassed by a Crop selecting the first element along axis 0, connected
    producer -> Crop -> consumer.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' has an 'id' attribute used for naming
    """
    node = match['op']
    node_id = node['id']

    if node.in_port(0).disconnected():
        # No producer: expose each consumer as a separate model input.
        # (enumerate replaces the original hand-maintained counter.)
        for i, dest in enumerate(node.out_port(0).get_destinations()):
            new_in = Parameter(graph, {'name': "Parameter_" + str(i) + "_for_" + node_id,
                                       'shape': dest.data.get_shape()}).create_node()
            dest.disconnect()
            new_in.out_port(0).connect(dest)
            # NOTE(review): the message pairs the new Parameter with a
            # "Result_for_" name — presumably the matching output created by
            # the sibling branch below; verify against the original pass.
            log.error("Add input/output mapped {} -> {} ".format(new_in.name,
                                                                 "Result_for_" + node_id),
                      extra={'is_warning': True})
    else:
        # Producer present: bypass the node with a first-batch-element Crop.
        out_node_port = node.out_port(0).get_destination()
        in_node_port = node.in_port(0).get_source()
        node.in_port(0).disconnect()
        node.out_port(0).disconnect()
        crop = Crop(graph, {'name': 'Result_for_' + node_id,
                            'dim': np.array([1]),
                            'offset': np.array([0]),
                            'axis': np.array([0])}).create_node()
        in_node_port.connect(crop.in_port(0))
        crop.out_port(0).connect(out_node_port)
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=float):
    """Build a subgraph producing zeros of shape ``[batch, second_dim]``.

    The batch dimension is taken at runtime from the shape of the tensor on
    ``input_out_port`` (ShapeOf + Crop of element 0), concatenated with the
    constant ``second_dim``, and used to Broadcast a zero scalar. Intended
    as the initializer feeding a ReadValue.

    :param input_out_port: output port whose tensor supplies the batch size
    :param second_dim: size of the second dimension of the zero tensor
    :param precision: numpy dtype of the zero fill value. The default was
        ``np.float`` (a deprecated alias of builtin ``float``, removed in
        NumPy 1.24); ``float`` is the exact backward-compatible equivalent.
    :return: the Broadcast node producing the zero tensor
    """
    # create init_graph connected to ReadValue
    graph = input_out_port.node.graph
    input_name = input_out_port.node.name

    # Runtime shape of the reference input.
    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
    shape_of_input.in_port(0).connect(input_out_port)

    # Crop element 0 of the shape vector -> the batch size.
    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/' + shape_of_input.name,
                                      'value': int64_array([1]),
                                      'shape': int64_array([1])}).create_node()
    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
                             'axis': int64_array([0]),
                             'offset': int64_array([0])}).create_node()
    get_batch.in_port(0).connect(shape_of_input.out_port(0))
    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))

    # Target shape = [batch, second_dim].
    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/' + input_name,
                                      'value': int64_array([second_dim]),
                                      'shape': int64_array([1])}).create_node()
    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
                               'axis': 0,
                               'in_ports_count': 2}).create_node()
    mem_shape.in_port(0).connect(get_batch.out_port(0))
    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))

    # Broadcast a zero scalar of the requested precision to the target shape.
    fill_value = Const(graph, {'name': 'fill_value/' + input_name,
                               'value': np.array([0.0], precision),
                               'shape': int64_array([1])}).create_node()
    init_value_prev_lstm_output = Broadcast(graph, {
        'name': 'init_value/' + input_name,
    }).create_node()
    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))

    return init_value_prev_lstm_output
def replace_sub_graph(graph: Graph, match: dict, **kwargs):
    """Adapt a matched NonMaxSuppression node's inputs and output.

    Boxes get a leading batch axis (Unsqueeze), scores are reshaped to
    rank 3, and the selected-indices output is narrowed to its last column
    (Crop) and squeezed along axis 1. A non-zero ``soft_nms_sigma`` input
    is reported as unsupported.

    :param graph: graph being transformed in place
    :param match: pattern match dict with key 'nms'
    """
    nms = match['nms']

    # prepare inputs to the NonMaximumSuppression Node
    unsqueeze_boxes = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]),
        {'name': nms.soft_get('name') + '/Unsqueeze_0'})
    nms.in_port(0).get_connection().insert_node(unsqueeze_boxes)

    unsqueeze_box_scores = create_op_node_with_second_input(
        graph, Reshape, int64_array([1, 1, -1]),
        {'name': nms.soft_get('name') + '/Unsqueeze_1'})
    nms.in_port(1).get_connection().insert_node(unsqueeze_box_scores)

    # prepare output: keep only the last of the three index columns, then
    # squeeze the degenerate axis.
    crop_box_indices = Crop(graph, {'name': nms.soft_get('name') + '/Crop',
                                    'axis': int64_array([1]),
                                    'offset': int64_array([2]),
                                    'dim': int64_array([1])}).create_node()
    nms.out_port(0).get_connection().insert_node(crop_box_indices)

    squeeze_output_boxes = create_op_node_with_second_input(
        graph, Squeeze, int64_array([1]),
        {'name': crop_box_indices.soft_get('name') + '/Squeeze'})
    crop_box_indices.out_port(0).get_connection().insert_node(squeeze_output_boxes)

    # Only soft_nms_sigma == 0 is supported; warn otherwise.
    if 5 in nms.in_ports() and not nms.in_port(5).disconnected():
        soft_nms_sigma = nms.in_port(5).get_source().data.get_value()
        if soft_nms_sigma is not None and soft_nms_sigma != 0.0:
            # Fixed: the original adjacent string literals concatenated to
            # "...only value 0is supported..." (missing space).
            log.error('The input to layer "{}" with value for the soft_nms_sigma is equal to "{}" '
                      'but only value 0 is supported. The inference results will be incorrect.'.
                      format(nms.soft_get('name'), soft_nms_sigma))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Add a variance branch next to a matched slice_like prior-box subgraph.

    Collects the four variance scalars from ``self.variants_pattern``
    matches (defaults 0.1/0.2 when a scalar is not found), tiles them to
    the prior-box constant's shape, crops that constant through the same
    slice_like parameters, and appends the result as a new input of the
    downstream Concat.

    :param graph: graph being transformed in place
    :param match: pattern match dict with keys 'slice_like' and 'reshape3'
    """
    slice_like = match['slice_like']
    const = slice_like.in_nodes()[0]
    crop_shape = slice_like.in_nodes()[1]

    # Default SSD-style variances, possibly overridden by scalars found in
    # the graph via the variants pattern below.
    variants_dict = {'mul_scalar1x': 0.1,
                     'mul_scalar2x': 0.2,
                     'mul_scalar1y': 0.1,
                     'mul_scalar2y': 0.2}
    for matches in find_pattern_matches(graph, self.variants_pattern['nodes'],
                                        self.variants_pattern['edges'], None, None):
        for k, v in matches.items():
            if v in variants_dict.keys():
                # Second input of the matched Mul carries the scalar value.
                variants_dict[v] = Node(graph, k).in_nodes()[1].value[0]

    # Tile the 4 variances over the whole prior-box constant (x, y order
    # interleaved as x1, y1, x2, y2 per box).
    variants = np.array([variants_dict['mul_scalar1x'],
                         variants_dict['mul_scalar1y'],
                         variants_dict['mul_scalar2x'],
                         variants_dict['mul_scalar2y']] *
                        int(const.value.size / 4)).reshape(const.value.shape)

    priorbox_variants = Const(graph, dict(value=variants,
                                          symbol_dict={'name': const.id + '/priorbox_variants'})).create_node()
    # Crop the variance constant exactly like the original slice_like crops
    # the prior boxes, so the two branches stay shape-aligned.
    variants_slice_like = Crop(graph, dict(axis=slice_like.axis,
                                           offset=slice_like.offset,
                                           dim=slice_like.dim,
                                           axes=slice_like.axes,
                                           symbol_dict={'name': slice_like.id + '/variants_slice_like'})) \
        .create_node()
    variants_slice_like.in_port(0).connect(priorbox_variants.out_port(0))
    variants_slice_like.in_port(1).connect(crop_shape.out_port(0))

    # Append the variance branch as an extra input of the existing Concat.
    concat = match['reshape3'].out_port(0).get_destination().node
    assert concat.op == 'Concat'
    concat_nodes_count = len(concat.in_nodes())
    concat.add_input_port(concat_nodes_count)
    concat.in_port(concat_nodes_count).get_connection().set_source(
        variants_slice_like.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    """Convert a Caffe/ONNX Slice node to StridedSlice or Crop.

    ONNX opset-10 slices are delegated to ``convert_onnx_slice_opset10``.
    For the Caffe attribute form ('start'/'end'), a slice touching at most
    one axis becomes a StridedSlice with begin/end masks; a multi-axis
    slice becomes a Crop with per-axis offset/dim. The only code change
    versus the original is renaming the local ``input`` (which shadowed the
    builtin) to ``in_node``.

    :param graph: graph being transformed in place
    :param match: pattern match dict with key 'slice'
    """
    node = match['slice']
    in_node = node.in_node(0)
    output_data = node.out_node()

    # ONNX 10 opset case
    if len(node.in_nodes()) >= 3 and node.has_valid('format') and node['format'] == 'onnx':
        self.convert_onnx_slice_opset10(node)
        return

    # Caffe case
    if not node.has_valid('start') or not node.has_valid('end'):
        return

    begin = node.start
    end = node.end
    axis = node.axis if node.has_valid('axis') else np.arange(begin.size)

    # Check whether operation use only one axis or not
    axes_begin = np.zeros(len(in_node.shape), dtype=np.int32)
    axes_end = np.zeros(len(in_node.shape), dtype=np.int32)
    ss_begin = np.zeros(len(in_node.shape), dtype=np.int32)
    ss_end = np.zeros(len(in_node.shape), dtype=np.int32)
    dims = 0
    axes = np.zeros(begin.size)
    for i in range(len(axis)):
        # An axis is "used" when it is actually narrowed at either end.
        if begin[i] != 0 or end[i] < in_node.shape[axis[i]]:
            dims += 1
            axes[i] = 1
            if begin[i] != 0:
                axes_begin[axis[i]] = 1
                ss_begin[axis[i]] = begin[i]
            if end[i] < in_node.shape[axis[i]]:
                axes_end[axis[i]] = 1
                ss_end[axis[i]] = end[i]
    axes = np.array(axes, dtype=bool)

    slice_node_name = node.soft_get('name', node.id)

    if dims == 1 or dims == 0:
        # If Slice use only one axis or no axis, than
        # convert Slice to StridedSlice
        ss = StridedSlice(graph,
                          dict(new_axis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
                               shrink_axis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
                               ellipsis_mask=np.zeros(len(output_data.shape), dtype=np.int32),
                               begin_mask=axes_begin,
                               end_mask=axes_end)).create_node()

        convert_negative_indices(ss_begin, in_node.shape)
        convert_negative_indices(ss_end, in_node.shape)

        begin_node = Const(graph, {'value': ss_begin,
                                   'name': slice_node_name + '/begin'}).create_node()
        end_node = Const(graph, {'value': ss_end,
                                 'name': slice_node_name + '/end'}).create_node()

        rename_nodes([(node, slice_node_name + '_delete'), (ss, slice_node_name)])

        node.in_port(0).get_connection().set_destination(ss.in_port(0))
        begin_node.out_port(0).connect(ss.in_port(1))
        end_node.out_port(0).connect(ss.in_port(2))
        node.out_port(0).get_connection().set_source(ss.out_port(0))
    else:
        # If Slice use more than one axis use Crop layer
        crop = Crop(graph, dict(axis=axis[axes],
                                offset=begin[axes],
                                dim=end[axes] - begin[axes])).create_node()
        rename_nodes([(node, slice_node_name + '_delete'), (crop, slice_node_name)])

        node.in_port(0).get_connection().set_destination(crop.in_port(0))
        node.out_port(0).get_connection().set_source(crop.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """Replace a negative-offset node pair with a ReadValue/Assign state loop.

    Same scheme as the plain variant: a variable buffer of ``|t|`` frames is
    read (ReadValue), shifted by one frame with the current input appended,
    written back (Assign), and the oldest frame is handed to consumers.
    Additionally this variant forces ``static_shape`` on, because the
    resulting construction is not reshape-able.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is one half of the offset pair
    :raises Error: for non-negative offsets, which this pass does not support
    """
    node = match['op']
    pair_node = Node(graph, node.pair_name)

    if node.t >= 0:
        raise Error('Does not support IfDefined with t > 0')

    # Exactly one of the paired nodes has a connected input; derive the data
    # source, the consumer port and variable naming from whichever it is.
    if node.in_port(0).get_source() is not None:
        input_port = node.in_port(0).get_source()
        op_output_id = node.out_port(0).get_destination().node.id
        out_port = pair_node.out_port(0)
        node_name = node.name
        pair_name = pair_node.name
    else:
        input_port = pair_node.in_port(0).get_source()
        op_output_id = pair_node.out_port(0).get_destination().node.id
        out_port = node.out_port(0)
        node_name = pair_node.name
        pair_name = node.name

    in_shape = input_port.data.get_shape()
    node_t = abs(node.t)

    # State buffer of node_t frames, zero-initialized per the runtime batch.
    init_value_memory_out = create_zero_value_with_batch_from_input(
        input_port, in_shape[1] * node_t)
    memory_out = ReadValue(graph, {'name': pair_name,
                                   'variable_id': node_name + pair_name}).create_node()
    init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

    if node_t > 1:
        # Shift the buffer: drop the oldest frame, append the new one.
        crop_concat = Crop(graph, {'name': 'Memory_crop',
                                   'dim': np.array([in_shape[1] * (node_t - 1)]),
                                   'offset': np.array([in_shape[1]]),
                                   'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_concat.in_port(0))

        concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
        concat.add_sequence_of_ports('in', range(2))
        crop_concat.out_port(0).connect(concat.in_port(0))
        concat.in_port(1).connect(input_port)

        # Write the shifted buffer back; Result keeps the Assign alive.
        memory_in = Assign(graph, {'name': node_name,
                                   'variable_id': node_name + pair_name}).create_node()
        concat.out_port(0).connect(memory_in.in_port(0))
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))

        # Consumers read the oldest stored frame (delay of node_t steps).
        crop_out = Crop(graph, {'name': 'Memory_crop_out',
                                'dim': np.array([in_shape[1]]),
                                'offset': np.array([0]),
                                'axis': np.array([1])}).create_node()
        memory_out.out_port(0).connect(crop_out.in_port(0))
        out_port.get_connection().set_source(crop_out.out_port(0))
    else:
        # Single-frame state: store the current input as-is.
        memory_in = Assign(graph, {'name': node_name,
                                   'variable_id': node_name + pair_name}).create_node()
        memory_in.in_port(0).connect(input_port)
        out = Result(graph, {'name': 'Memory_output'}).create_node()
        memory_in.out_port(0).connect(out.in_port(0))
        out_port.get_connection().set_source(memory_out.out_port(0))

    # The state construction fixes the batch dimension, so reshape would
    # break the model: force static shapes and warn the user once.
    if not graph.graph['cmd_params'].static_shape:
        log.error(
            "Model can not be translated in a reshape-able way.\n"
            "Model Optimizer key static_shape was turned on to prevent related errors.\n"
            "There will be no success changing input shapes of the model with the help of "
            "InferenceEngine reshape method", extra={'is_warning': True})
        graph.graph['cmd_params'].static_shape = True

    # Drop the consumed output node and both halves of the pair.
    graph.remove_node(op_output_id)
    graph.remove_node(node.id)
    graph.remove_node(pair_node.id)
def replace_pattern(self, graph: Graph, match: dict):
    """
    Converts specific for NasNet topology subgraph Pad->StridedSlice->AvgPool to Conv->Crop->AvgPool

    The Pad/StridedSlice pair is only replaced when it matches the exact
    NasNet constants (pad end [0,1,1,0], slice begin [0,1,1,0] with the
    corresponding masks); any mismatch is logged and the graph is left
    untouched. The replacement is a depthwise 1x1 identity Convolution that
    realizes the padding, followed by a Crop realizing the slice.
    """
    input = match['input']

    pad_node = match['pad_op']
    pad_node_name = pad_node.soft_get('name', pad_node.id)

    sslice_node = match['sslice']
    # Decompose the slice spec into begin/end/stride per axis.
    begin = []
    end = []
    stride = []
    for s in sslice_node.slices:
        begin.append(s.start)
        end.append(s.stop)
        stride.append(s.step)

    pads_begin = pad_node.in_port(1).data.get_value()
    pads_end = pad_node.in_port(2).data.get_value()
    if pads_begin is None or pads_end is None:
        log.error('Pad values for node "{}" are not constants'.format(pad_node_name))
        return

    # The transform only handles the exact NasNet pattern; bail out on any
    # deviation from the expected constants.
    if not np.array_equal(pads_begin, int64_array([0, 0, 0, 0])):
        log.error('Pad begin values doesn\'t match for node {}!'.format(pad_node_name))
        return

    if not np.array_equal(pads_end, int64_array([0, 1, 1, 0])):
        log.error('Pad end values doesn\'t match for node {}!'.format(pad_node_name))
        return

    if not np.array_equal(begin, int64_array([0, 1, 1, 0])):
        log.error("StridedSlice has wrong begin")
        return

    if not np.array_equal(sslice_node.end_mask, int64_array([0, 0, 0, 0])) or \
            not np.array_equal(sslice_node.begin_mask, int64_array([0, 1, 1, 0])):
        log.error("StridedSlice has wrong masks")
        return

    # Pad -> Conv
    conv_name = graph.unique_id(pad_node.name + '/Conv_')
    conv_weights_name = graph.unique_id(pad_node.name + '/ConvW_')
    # Identity weights: one 1x1 kernel per channel (depthwise, NHWC layout).
    conv_weights = np.ones((input.shape[3], 1, 1, 1))
    # Output grows by one along each spatial dim due to the asymmetric pad.
    output_shape = int64_array([input.shape[0], input.shape[1] + 1,
                                input.shape[2] + 1, input.shape[3]])

    conv_node = Convolution(graph,
                            dict(name=conv_name,
                                 stride=int64_array([1, 1, 1, 1]),
                                 dilation=int64_array([1, 1, 1, 1]),
                                 group=input.shape[3],
                                 bias_addable=True,
                                 bias_term=False,
                                 spatial_dims=int64_array([1, 2]),
                                 kernel_spatial=int64_array([1, 1]),
                                 pad=int64_array([[0, 0], [0, 1], [0, 1], [0, 0]]),
                                 output_shape=output_shape,
                                 batch_dims=int64_array([0]),
                                 channel_dims=int64_array([3]),
                                 output=input.shape[3],
                                 input_feature_channel=1,
                                 output_feature_channel=0,
                                 )).create_node()

    weights_const_node = Const(graph, dict(name=conv_weights_name,
                                           value=conv_weights,
                                           shape=int64_array(conv_weights.shape))).create_node()

    # StridedSlice -> Crop: drop the first row/column that the slice skipped.
    crop_node = Crop(graph,
                     dict(name=sslice_node.name + '/Crop_',
                          axis=int64_array([1, 2]),
                          dim=int64_array([output_shape[1] - 1, output_shape[2] - 1]),
                          offset=int64_array([1, 1]))).create_node()

    # Connect nodes
    pad_node.in_port(0).get_connection().set_destination(conv_node.in_port(0))
    weights_const_node.out_port(0).connect(conv_node.in_port(1))
    conv_node.out_port(0).connect(crop_node.in_port(0))
    sslice_node.out_port(0).get_connection().set_source(crop_node.out_port(0))

    # Mark the weights input so it is serialized as binary data.
    conv_node.in_port(1).bin = 'weights'

    # Remove Pad and StridedSlice nodes from graph
    graph.remove_node(pad_node.id)
    graph.remove_node(sslice_node.id)
def replace_pattern(graph: Graph, match: dict):
    """Guard a memory write with a Select driven by an iteration counter.

    Until the temporal context preceding the matched node is fully gathered,
    the data flowing into it is garbage; a Select chooses between real data
    and zeros. The Select's condition comes from a shared iteration-counter
    subgraph (Memory -> Crop -> Concat-with-ones -> Memory) that is reused
    if an identical one already exists in the graph.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is the memory-write node
    """
    node = match['op']
    # Skip the counter's own memory node to avoid recursing into ourselves.
    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    inputs = []
    for n in graph.get_op_nodes(**{'op': 'Parameter'}):
        inputs.append(n)
    in_nodes = []
    for inp in inputs:
        for ins in inp.out_port(0).get_destinations():
            in_nodes.append(ins.node.name)
    # Context length = 1 + sum over all Splices on the path from the inputs.
    context_len = 1
    try:
        subgraph = invert_sub_graph_between_nodes(
            graph, [node.in_port(0).get_source().node.name], in_nodes)
    except Error:
        return

    for n in subgraph:
        n_node = Node(graph, n)
        if n_node.kind == 'op' and n_node.op == 'Splice':
            context_len += len(n_node.context) - 1

    # No context to gather — data is valid from the first frame.
    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else',
                              'value': np.zeros(in_node_shape)}).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(
        graph,
        nodes=[('mem_in', dict(op='Memory', index=1,
                               shape=int64_array([context_len]))),
               ('mem_in_data', dict()),
               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                    offset=int64_array([1]),
                                    dim=int64_array([context_len - 1]))),
               ('crop_mem_in_data', dict()),
               ('concat', dict(op='Concat', axis=1)),
               ('concat_data', dict()),
               ('const_1', dict(op='Const')),
               ('const_1_data', dict()),
               ('mem_out', dict(op='Memory', index=0,
                                shape=int64_array([context_len]))),
               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                 offset=int64_array([0]),
                                 dim=int64_array([1]))),
               ('crop_out_data', dict()),
               ('select', dict(op='Select'))],
        edges=[('mem_in', 'mem_in_data'),
               ('mem_in_data', 'crop_mem_in'),
               ('crop_mem_in', 'crop_mem_in_data'),
               ('crop_mem_in_data', 'concat', {'in': 0}),
               ('const_1', 'const_1_data'),
               ('const_1_data', 'concat', {'in': 1}),
               ('concat', 'concat_data'),
               ('concat_data', 'mem_out'),
               ('concat_data', 'crop_out'),
               ('crop_out', 'crop_out_data'),
               ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        # Reuse the existing counter's output port as the Select condition.
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        # Build the counter: a sliding window over [context_len] that shifts
        # in a 1 every frame; its first element becomes 1 only once the
        # whole context has been observed.
        mem_out = Memory(graph, {'name': 'iteration_number',
                                 'size': 2,
                                 'index': 1,
                                 'id': 'iteration_' + node.name,
                                 'shape': int64_array([context_len]),
                                 'dst_type': np.int32}).create_node()
        cut_first = Crop(graph, {'name': 'cut_first',
                                 'axis': int64_array([1]),
                                 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = Const(graph, {'name': 'ones',
                             'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones',
                                'in_ports_count': 2,
                                'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Memory(graph, {'name': 'iteration_number_out',
                                'size': 2,
                                'index': 0,
                                'id': 'iteration_' + node.name,
                                'shape': int64_array([context_len])}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, {'name': 'cut_last',
                                'axis': int64_array([1]),
                                'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Condition -> Select -> original memory-write input.
    select_node.in_port(0).connect(input_port)
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def replace_pattern(graph: Graph, match: dict):
    """Replace a Splice node with an explicit Memory-based sliding window.

    A Memory(in)/Memory(out) pair holds the context window; each frame the
    oldest element is cropped away and the new input is concatenated on.
    When ``node.const_dim`` is non-zero, the constant part of each frame is
    split off and run through its own parallel Memory/Crop/Concat chain,
    then re-joined with the spliced part at the end.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is the Splice node
    """
    node = match['op']
    in_shape = node.in_port(0).data.get_shape().copy()
    # Per-frame size of the spliced (non-constant) part, and the full
    # window size across the whole context.
    memory_element = in_shape[1] - node.const_dim
    memory_size = memory_element * len(node.context)

    memory_pair_id = unique_id('id')
    # Memory(in)
    input_memory = Memory(graph, {'name': 'prev_splice_memory',
                                  'id': memory_pair_id,
                                  'index': 1,
                                  'size': 2,
                                  'shape': int64_array([memory_size])}).create_node()
    # Memory(in)  \
    #             Crop
    # Input(temp) /
    crop = Crop(graph, {'name': 'Splice_Crop',
                        'axis': int64_array([1]),
                        'offset': int64_array([memory_element]),
                        'dim': int64_array([memory_size - memory_element])}).create_node()
    crop.in_port(0).connect(input_memory.out_port(0))

    # Crop   \
    #         Concat
    # Input  /
    concat_node = Concat(graph, {'name': 'Splice_Concat',
                                 'in_ports_count': 2,
                                 'axis': 1}).create_node()
    concat_node.in_port(0).connect(crop.out_port(0))

    # Concat -> Memory(out)
    mem_out = Memory(graph, {'name': 'out_splice_memory',
                             'id': memory_pair_id,
                             'index': 0,
                             'size': 2,
                             'shape': int64_array([memory_size])}).create_node()
    mem_out.in_port(0).connect(concat_node.out_port(0))
    # Result keeps the write side of the memory pair alive in the graph.
    Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

    if node.const_dim != 0:
        # The last const_dim features of each frame are constant across the
        # context and get their own parallel window.
        memory_element_constdim = node.const_dim
        memory_size_constdim = memory_element_constdim * len(node.context)

        # Split input into [spliced part, constant part] along axis 1.
        split = create_op_with_const_inputs(
            graph, VariadicSplit,
            {1: int64_array(1),
             2: int64_array([memory_element, memory_element_constdim])},
            {'name': node.id + '_split_const', 'out_ports_count': 2})

        split.out_port(0).connect(concat_node.in_port(1))

        # create separate splice construction for const_dim
        memory_pair_id = unique_id('memory_for_const_dim')
        input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
                                                'id': memory_pair_id,
                                                'index': 1,
                                                'size': 2,
                                                'shape': int64_array([memory_size_constdim])}).create_node()
        crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
                                      'axis': int64_array([1]),
                                      'offset': int64_array([memory_element_constdim]),
                                      'dim': int64_array(
                                          [memory_size_constdim - memory_element_constdim])}).create_node()
        crop_const_dim.in_port(0).connect(input_memory_const_dim.out_port(0))

        concat_node_const_dim = Concat(graph, {'name': 'const_dim_concat',
                                               'in_ports_count': 2,
                                               'axis': 1}).create_node()
        concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

        mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
                                           'id': memory_pair_id,
                                           'index': 0,
                                           'size': 2,
                                           'shape': int64_array([memory_size_constdim])}).create_node()
        mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
        Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

        # connect splice to Split as begin and Concat as the end
        split.out_port(1).connect(concat_node_const_dim.in_port(1))
        # Only one copy of the constant part is needed on the output: crop
        # the first frame's worth and append it after the spliced window.
        crop_first = Crop(graph, {'name': 'const_dim_crop_first',
                                  'axis': int64_array([1]),
                                  'offset': int64_array([0]),
                                  'dim': int64_array([memory_element_constdim])}).create_node()
        crop_first.in_port(0).connect(concat_node_const_dim.out_port(0))

        concat_const = Concat(graph, {'name': node.id + '_concat_const',
                                      'axis': 1,
                                      'in_ports_count': 2}).create_node()
        concat_const.in_port(1).connect(crop_first.out_port(0))
        concat_const.in_port(0).connect(concat_node.out_port(0))

        node.in_port(0).get_connection().set_destination(split.in_port(0))
        node.out_port(0).get_connection().set_source(concat_const.out_port(0))
    else:
        # No constant part: the input feeds the window directly and the
        # window is the output.
        node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
        node.out_port(0).get_connection().set_source(concat_node.out_port(0))

    # to avoid re-inference of shape and touching in next replacements
    graph.remove_node(node.id)
def replace_pattern(graph: Graph, match: dict):
    """Guard a state write with a Select driven by a ReadValue/Assign counter.

    Until the temporal context preceding the matched node is fully gathered,
    the data flowing into it is garbage; a Select chooses between real data
    and zeros. The condition is taken from a shared iteration-counter
    subgraph (ReadValue -> Crop -> Concat-with-ones -> Assign), reused if an
    identical one already exists, and compared to 1 via Equal.

    :param graph: graph being transformed in place
    :param match: pattern match dict; 'op' is the state-write node
    """
    node = match['op']
    # Skip the counter's own write node to avoid recursing into ourselves.
    if node.name == 'iteration_number_out':
        return

    # calculate length of context when state of inference becomes meaningful
    inputs = []
    for n in graph.get_op_nodes(**{'op': 'Parameter'}):
        inputs.append(n)
    in_nodes = []
    for inp in inputs:
        for ins in inp.out_port(0).get_destinations():
            in_nodes.append(ins.node.name)
    # Context length = 1 + sum over all Splices on the path from the inputs.
    context_len = 1
    try:
        subgraph = invert_sub_graph_between_nodes(
            graph, [node.in_port(0).get_source().node.name], in_nodes)
    except Error:
        return

    for n in subgraph:
        n_node = Node(graph, n)
        if n_node.kind == 'op' and n_node.op == 'Splice':
            context_len += len(n_node.context) - 1

    # No context to gather — data is valid from the first frame.
    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
    zero_else = Const(graph, {'name': 'zero_else',
                              'value': np.zeros(in_node_shape)}).create_node()
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(
        graph,
        nodes=[('mem_in', dict(op='ReadValue')),
               ('mem_in_data', dict(shape=int64_array([context_len]))),
               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                    offset=int64_array([1]),
                                    dim=int64_array([context_len - 1]))),
               ('crop_mem_in_data', dict()),
               ('concat', dict(op='Concat', axis=1)),
               ('concat_data', dict()),
               ('const_1', dict(op='Const')),
               ('const_1_data', dict()),
               ('mem_out', dict(op='Assign')),
               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                 offset=int64_array([0]),
                                 dim=int64_array([1]))),
               ('crop_out_data', dict()),
               ('select', dict(op='Select'))],
        edges=[('mem_in', 'mem_in_data'),
               ('mem_in_data', 'crop_mem_in'),
               ('crop_mem_in', 'crop_mem_in_data'),
               ('crop_mem_in_data', 'concat', {'in': 0}),
               ('const_1', 'const_1_data'),
               ('const_1_data', 'concat', {'in': 1}),
               ('concat', 'concat_data'),
               ('concat_data', 'mem_out'),
               ('concat_data', 'crop_out'),
               ('crop_out', 'crop_out_data'),
               ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        # Reuse the existing counter's ones constant and output port.
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        # Build the counter: a sliding window over [context_len] that shifts
        # in a 1 every frame; its first element becomes 1 only once the
        # whole context has been observed.
        init_value_mem_out = create_zero_value_with_batch_from_input(
            in_node_port, context_len, np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
        cut_first = Crop(graph, {'name': 'cut_first',
                                 'axis': int64_array([1]),
                                 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))
        ones = Const(graph, {'name': 'ones',
                             'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones',
                                'in_ports_count': 2,
                                'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))

        cut_last = Crop(graph, {'name': 'cut_last',
                                'axis': int64_array([1]),
                                'offset': int64_array([0]),
                                'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def insert_select(graph: Graph, node: Node) -> None:
    """Insert a Select in front of *node* so that garbage produced before the
    full temporal context has been gathered is not written into memory.

    A counter sub-graph (ReadValue -> Crop -> Concat with a constant 1 ->
    Assign) tracks how many iterations have been processed; its last element
    is compared with 1 (Equal) and drives the Select condition: the real data
    is passed through only once the context is complete, otherwise a zero
    constant of the same shape is substituted.

    :param graph: graph being transformed
    :param node: node in front of which the Select is inserted; its
                 ``frame_time`` attribute defines the context length
    """
    context_len = node.frame_time + 1

    # No context to accumulate -- first output is already valid, no Select needed.
    if context_len == 1:
        return

    in_node_port = node.in_port(0).get_source()
    in_node_shape = node.in_port(0).data.get_shape()
    node.in_port(0).disconnect()

    # add Select before saving state to avoid saving garbage
    select_node = Select(graph, {'name': 'select_' + node.name}).create_node()

    # 'else' branch of the Select: zeros with the same batch as the input
    # (NOTE(review): assumes in_node_shape is [batch, feature] -- confirm)
    zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
    select_node.in_port(1).connect(in_node_port)
    select_node.in_port(2).connect(zero_else.out_port(0))

    # check if we have already appropriate iteration counter
    existing_counters = find_pattern_matches(graph,
                                             nodes=[('mem_in', dict(op='ReadValue')),
                                                    ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                    ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                         offset=int64_array([1]),
                                                                         dim=int64_array([context_len - 1]))),
                                                    ('crop_mem_in_data', dict()),
                                                    ('concat', dict(op='Concat', axis=1)),
                                                    ('concat_data', dict()),
                                                    ('const_1', dict(op='Const')),
                                                    ('const_1_data', dict()),
                                                    ('mem_out', dict(op='Assign')),
                                                    ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                      offset=int64_array([0]),
                                                                      dim=int64_array([1]))),
                                                    ('crop_out_data', dict()),
                                                    ('select', dict(op='Select'))
                                                    ],
                                             edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                    ('crop_mem_in', 'crop_mem_in_data'),
                                                    ('crop_mem_in_data', 'concat', {'in': 0}),
                                                    ('const_1', 'const_1_data'),
                                                    ('const_1_data', 'concat', {'in': 1}),
                                                    ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                    ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                    ('crop_out_data', 'select')])
    counter_match = next(existing_counters, None)
    if counter_match is not None:
        # A counter sub-graph with the same context length already exists --
        # reuse its constant-1 node and its final Crop output.
        ones = Node(graph, inverse_dict(counter_match)['const_1'])
        input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
    else:
        # Build the iteration counter: an int32 memory of `context_len` values
        # shifted left by one each iteration with a 1 appended at the end.
        init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
        mem_out = ReadValue(graph, {'name': 'iteration_number',
                                    'variable_id': 'iteration_' + node.name}).create_node()
        mem_out.in_port(0).connect(init_value_mem_out.out_port(0))

        # Drop the oldest element of the counter memory.
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]),
                                 'dim': int64_array([context_len - 1])}).create_node()
        cut_first.in_port(0).connect(mem_out.out_port(0))

        # Append a constant 1, marking one more processed iteration.
        ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        # Store the updated counter back into the memory variable.
        mem_in = Assign(graph, {'name': 'iteration_number_out',
                                'variable_id': 'iteration_' + node.name}).create_node()
        mem_in.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        mem_in.out_port(0).connect(res.in_port(0))

        # The last counter element is 1 only once `context_len` iterations passed.
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        input_port = cut_last.out_port(0)

    # Check if data from memory is 1
    # if it is True, we have correct data and should proceed with saving it to memory
    # else we have not gathered context and have garbage here, shouldn't change initial state of memory
    cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
    cast_in.in_port(0).connect(ones.out_port(0))
    cast_in.in_port(1).connect(input_port)
    select_node.in_port(0).connect(cast_in.out_port(0))
    select_node.out_port(0).connect(node.in_port(0))
    select_node.out_port(0).data.set_shape(in_node_shape)
def replace_pattern(graph: Graph, match: dict):
    """Merge two neighboring Splice nodes fed by the same parent into one.

    Two Splice children of the same source can be merged when their contexts
    are adjacent (one context ends where the other begins). The node whose
    context starts first is kept (``new_node``) and gets the merged context;
    the other (``rem_node``) is removed, and each of its consumers is rewired
    to ``new_node`` through a Crop selecting the sub-range of the merged
    output that the consumer originally received.

    :param graph: graph being transformed
    :param match: pattern match dictionary; ``match['op']`` is a Splice node
                  with a ``context`` attribute (sorted list of frame offsets)
    """
    mem = match['op']
    mem_shape = mem.in_port(0).data.get_shape()
    mem_parent = mem.in_port(0).get_source()
    context = mem['context']

    for child_port in mem_parent.get_destinations():
        child = child_port.node
        # Merge only with a *different* Splice adjacent on either side.
        # BUGFIX: the second disjunct previously duplicated the first one
        # (`child['context'][0] == context[-1]` was tested twice), which made
        # the `else` branch below unreachable; it must test the opposite
        # adjacency -- `child`'s context ending where ours begins.
        if child['op'] == 'Splice' and child.id != mem.id and \
                (child['context'][0] == context[-1] or child['context'][-1] == context[0]):
            # Union of both contexts, de-duplicated and sorted.
            new_context = list(context)
            new_context.extend(list(child['context']))
            new_context = list(set(new_context))
            new_context.sort()

            # Keep the node whose context starts first; remove the other.
            if child['context'][0] == context[-1]:
                new_node = mem
                rem_node = child
            else:
                new_node = child
                rem_node = mem

            # reset edges from rem_node to new_node
            for out_port_rem in rem_node.out_port(0).get_destinations():
                out_transfer = out_port_rem.node
                out_transfer_shape = out_port_rem.data.get_shape().copy()
                out_port_rem.disconnect()
                if out_transfer['op'] == 'Crop':
                    # modify existing Crop to get right data from larger Splice
                    out_transfer['offset'] = out_transfer['offset'] + \
                        (len(new_context) - len(rem_node.context)) * mem_shape[-1]
                    out_port_rem.connect(new_node.out_port(0))
                else:
                    # insert Crop if we have not one
                    crop_node = Crop(graph, {'name': graph.unique_id(prefix='Splice_crop_'),
                                             'offset': (len(new_context) - len(rem_node.context)) * mem_shape[-1],
                                             'dim': np.array([len(rem_node['context']) * mem_shape[-1]]),
                                             'axis': np.array([-1])}).create_node()
                    new_node.out_port(0).connect(crop_node.in_port(0))
                    crop_node.out_port(0).connect(out_port_rem)
                    crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # The kept node's output becomes wider; shield every non-Crop
            # consumer with a Crop over the range it originally consumed.
            for out_port_rem in new_node.out_port(0).get_destinations():
                out_transfer = out_port_rem.node
                out_transfer_shape = out_port_rem.data.get_shape().copy()
                if out_transfer['op'] != 'Crop':
                    # insert Crop if we have not one
                    crop_node = Crop(graph, {'name': graph.unique_id(prefix='Splice_crop_'),
                                             'offset': np.array([0]),
                                             'dim': np.array([len(new_node['context']) * mem_shape[-1]]),
                                             'axis': np.array([-1])}).create_node()
                    new_node.out_port(0).connect(crop_node.in_port(0))
                    out_port_rem.disconnect()
                    crop_node.out_port(0).connect(out_port_rem)
                    crop_node.out_port(0).data.set_shape(out_transfer_shape)

            # Widen the kept node's output shape by rem_node's contribution
            # (the extra feature columns rem_node produced beyond its input).
            new_shape = new_node.out_port(0).data.get_shape()
            new_shape[1] += rem_node.out_port(0).data.get_shape()[1] - rem_node.in_port(0).data.get_shape()[1]
            new_node.out_port(0).data.set_shape(new_shape)
            new_node.context = new_context
            graph.remove_node(rem_node.id)