def find_and_replace_pattern(self, graph: Graph):
    for gather in graph.get_op_nodes(type='Gather'):
        indices = gather.in_port(1).get_source().node
        indices_value = gather.in_port(1).data.get_value()
        if indices.op == 'Const' and indices_value is not None and indices_value.ndim == 0:
            log.debug('The Gather node {} has constant 0D input with indices'.format(gather.id))
            new_indices = Const(graph, {'value': np.array([indices_value.item()])}).create_node()

            # the input shape is changed so need to disconnect port first
            gather.in_port(1).disconnect()
            gather.in_port(1).connect(new_indices.out_port(0))

            # the output of Gather is changed so need to run shape inference for it and override the existing shape
            gather['override_output_shape'] = True
            gather['need_shape_inference'] = True

            # insert Squeeze to remove the dimension 'axis' which becomes equal to 1 after the change of the Gather
            # indices constant
            squeeze = Squeeze(graph, {'name': gather.id + '/Squeeze'}).create_node()
            squeeze_axis = Const(graph, {'name': squeeze.id + '/axis',
                                         'value': int64_array([gather.axis])}).create_node()

            gather.out_port(0).get_connection().insert_node(squeeze)
            squeeze.in_port(1).connect(squeeze_axis.out_port(0))
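# Hedged illustration (not part of the transform above): in numpy terms, gathering with
# a 0D index drops the gathered axis, while gathering with a 1D index of length 1 keeps
# it, so switching to a 1D index and squeezing 'axis' preserves the original result.
import numpy as np

data = np.arange(12).reshape(3, 4)
scalar_gather = np.take(data, 1, axis=0)     # 0D index -> shape (4,)
vector_gather = np.take(data, [1], axis=0)   # 1D index -> shape (1, 4)
assert np.array_equal(scalar_gather, np.squeeze(vector_gather, axis=0))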
def test_squeeze_squeeze_dims(self, input_value, input_shape, squeeze_dims, ref_value, ref_shape):
    graph = build_graph(nodes_attributes,
                        [('data', 'squeeze'),
                         ('squeeze_dims', 'squeeze_dims_data'),
                         ('squeeze_dims_data', 'squeeze'),
                         ('squeeze', 'data_out')],
                        {'data': {'shape': input_shape, 'value': input_value},
                         'squeeze_dims': {'value': squeeze_dims, 'shape': squeeze_dims.shape},
                         'squeeze_dims_data': {'value': squeeze_dims, 'shape': squeeze_dims.shape},
                         })
    node = Node(graph, 'squeeze')
    if ref_shape is None:  # the test should fail
        with self.assertRaises(Error):
            Squeeze.infer(node)
    else:
        Squeeze.infer(node)
        if ref_value is not None:
            self.assertTrue(strict_compare_tensors(node.out_port(0).data.get_value(), ref_value))
        self.assertTrue(strict_compare_tensors(node.out_port(0).data.get_shape(), ref_shape))
def extract(node):
    attrs = get_mxnet_layer_attrs(node.symbol_dict)
    Squeeze.update_node_stat(node, {'squeeze_dims': attrs.int("axis", None),
                                    'keep_at_least_1d': True})
    return __class__.enabled
def extract(node):
    axis = np.array(onnx_attr(node, 'axes', 'ints', default=[]), dtype=np.int64)
    attrs = {'squeeze_dims': axis if len(axis) != 0 else None}

    # update the attributes of the node
    Squeeze.update_node_stat(node, attrs)
    return __class__.enabled
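# Hedged illustration: the ONNX 'axes' attribute lists the dimensions to remove, so a
# non-empty attribute maps directly to squeeze_dims; an absent attribute is mapped to
# None above (squeeze everything). In numpy terms for the explicit case:
import numpy as np

x = np.zeros((1, 3, 1, 5))
assert np.squeeze(x, axis=(0, 2)).shape == (3, 5)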
def replace_op(self, graph: Graph, node: Node):
    squeeze_op = Squeeze(graph, dict())
    squeeze_op.attrs['old_infer'] = squeeze_op.attrs['infer']
    squeeze_op.attrs['infer'] = __class__.do_infer
    squeeze_node = squeeze_op.create_node([], dict(name=node.name + '/Squeeze'))
    node.insert_node_after(squeeze_node)
    return []
def add_squeeze_for_shrink(graph: Graph, ss_node: Node):
    # add Squeeze for shrink_axis_mask
    log.info("StridedSlice op with shrink mask '{}' has been detected".format(ss_node.id))

    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    dim = np.array(range(len(ss_node['shrink_axis_mask'])))[np.array(ss_node['shrink_axis_mask'], dtype=bool)]
    ss_shape = []
    i = 0
    k = 0

    # Don't permute reshape if channels were squeezed
    dont_permute = graph.graph['layout'] == 'NCHW'
    if graph.graph['layout'] == 'NHWC' and ss_node['shrink_axis_mask'][-1] == 1:
        dont_permute = True

    while k < len(shape_out):
        if i >= len(ss_node['shrink_axis_mask']) or not ss_node['shrink_axis_mask'][i]:
            ss_shape.append(shape_out[k])
            k = k + 1
        else:
            ss_node['shrink_axis_mask'][i] = 0
            ss_shape.append(1)
        i = i + 1

    while i < len(ss_node['shrink_axis_mask']):
        ss_node['shrink_axis_mask'][i] = 0
        ss_shape.append(1)
        i = i + 1

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Squeeze
    squeeze_node = Squeeze(graph, dict(name=ss_node.name + '/Squeeze_shrink',
                                       nchw_layout=dont_permute,
                                       correct_data_layout=dont_permute)).create_node()
    ss_node.out_port(0).get_connection().insert_node(squeeze_node)
    squeeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': squeeze_node.id + '/Indices',
                              'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(squeeze_node.in_port(1))
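# Hedged illustration: a set bit in shrink_axis_mask makes StridedSlice drop that axis,
# exactly like indexing with a scalar instead of a 1-element slice. Clearing the mask
# and appending a Squeeze over the same axes, as done above, reproduces that behaviour:
import numpy as np

x = np.zeros((2, 5, 3))
shrunk = x[:, 0, :]    # shrink_axis_mask = [0, 1, 0] -> shape (2, 3)
kept = x[:, 0:1, :]    # mask cleared -> shape (2, 1, 3)
assert np.array_equal(shrunk, np.squeeze(kept, axis=1))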
def replace_op(self, graph: Graph, node: Node):
    for out_port in node.out_ports().values():
        squeeze_node = Squeeze(graph, dict(name=node.name + '/Squeeze_')).create_node([])
        dims_node = Const(graph, {'value': np.array(node.axis),
                                  'name': node.name + '/Squeeze_axis'}).create_node()

        out_port.get_connection().insert_node(squeeze_node)
        dims_node.out_port(0).connect(squeeze_node.in_port(1))

    # do not replace any output edge
    return []
def replace_sub_graph(self, graph: Graph, match: dict): """ In ONNX ArgMax operation has keepdims attribute that indicates whether to stay a dimension along which maximum is computed or not. In case of keepdims=0 this dimension should be removed but ArgMax operation in IR format is not designed to cover this case. So we should additionally add Squeeze operation right after ArgMax for this case. """ argmax_node = match['argmax'] axis = argmax_node.axis squeeze_node = Squeeze(graph, {'squeeze_dims': [axis]}).create_node() argmax_node.out_port(0).get_connection().set_source( squeeze_node.out_port(0)) squeeze_node.in_port(0).connect(argmax_node.out_port(0))
def replace_pattern(graph: Graph, match: dict): """ Workarounds not supported type of Tile in Inference Engine (Tiles are supported for 2-D or 4-D tensors): Searches for Tiles with 3D shapes and covers it with Reshapes. Example: Tile (axis=1, tiles=16): in_shape: [1,1,101] out_shape: [1,16,101] Old behaviour: Tile -> [1,16,101] New behaviour: Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101] """ node = match['tile'] name = node.soft_get('name', node.id) out_shape = node.out_port(0).data.get_shape() assert out_shape is not None, 'Output shape is undefined for {} in back phase'.format(name) if out_shape.size != 3: return inp_shape = node.in_port(0).data.get_shape() assert inp_shape is not None, 'Input shape is undefined for {} in back phase'.format(name) unsqueeze_dim = Const(graph, {'name': name + '/3D_Tile_Unsqueeze_dim', 'value': int64_array([3])}).create_node() unsqueeze = Unsqueeze(graph, {'name': name + '/3D_Tile_Unsqueeze', 'override_output_shape': True}).create_node() unsqueeze_dim.out_port(0).connect(unsqueeze.in_port(1)) const = Const(graph, {'name': name + '/additional_axis', 'value': int64_array([1])}).create_node() new_tiles = new_shape_node_from_shape_nodes([node.in_port(1).get_source().node, const]) node.in_port(1).get_connection().set_source(new_tiles.out_port(0)) squeeze_dim = Const(graph, {'name': name + '/3D_Tile_Squeeze_dim', 'value': int64_array([3])}).create_node() squeeze = Squeeze(graph, {'name': name + '/3D_Tile_Squeeze', 'override_output_shape': True}).create_node() squeeze_dim.out_port(0).connect(squeeze.in_port(1)) source = node.in_port(0).get_source() node.in_port(0).get_connection().set_source(unsqueeze.out_port(0)) unsqueeze.in_port(0).connect(source) node.out_port(0).get_connection().set_source(squeeze.out_port(0)) node.out_port(0).connect(squeeze.in_port(0)) node['override_output_shape'] = True new_tiles['override_output_shape'] = True node['need_shape_inference'] = True
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    for ind in range(len(node.out_nodes())):
        squeeze_node = Squeeze(graph, dict(squeeze_dims=[node.axis],
                                           name=node.name + '/Squeeze_')).create_node([])
        insert_node_after(node, squeeze_node, ind)

    # do not replace any output edge
    return []
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['argmax']
    connected_ports = [port for port in node.in_ports().values() if not port.disconnected()]

    squeeze_node = Squeeze(graph, dict()).create_node([], dict(name=node.name + '/Squeeze'))
    if len(connected_ports) == 2:
        node.in_port(1).get_source().connect(squeeze_node.in_port(1))
    else:
        axis_node = Const(graph, {'value': node.axis}).create_node()
        node.in_port(1).connect(axis_node.out_port(0))

    node.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
    node.out_port(0).connect(squeeze_node.in_port(0))
    return []
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(squeeze_axis=True):
        name = node.soft_get('name', node.id)
        for out_port in node.out_ports().values():
            if node.has_valid('axis'):
                squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
                                                           {'name': name + '/Squeeze_'})
                out_port.get_connection().insert_node(squeeze_node)
            elif node.is_in_port_connected(1):
                squeeze_node = Squeeze(graph, {'name': name + '/Squeeze_'}).create_node()
                out_port.get_connection().insert_node(squeeze_node)
                node.in_port(1).get_connection().add_destination(squeeze_node.in_port(1))
            else:
                raise Error('Unknown axis to squeeze for node {}'.format(name))
def test_squeeze_empty_squeeze_dims(self):
    graph = build_graph(nodes_attributes,
                        [('data', 'squeeze'),
                         ('squeeze_dims', 'squeeze_dims_data'),
                         ('squeeze_dims_data', 'squeeze'),
                         ('squeeze', 'data_out')],
                        {'data': {'shape': np.array([1, 2, 1, 4])},
                         'squeeze_dims': {'value': np.array([]), 'shape': np.array([1])},
                         'squeeze_dims_data': {'value': np.array([]), 'shape': np.array([1])},
                         })
    node = Node(graph, 'squeeze')
    Squeeze.infer(node)

    self.assertTrue(np.all(node.out_port(0).data.get_shape() == [2, 4]))
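# Hedged illustration: the empty squeeze_dims case exercised by this test mirrors
# numpy's default behaviour, which removes every dimension of size 1:
import numpy as np

assert np.squeeze(np.zeros((1, 2, 1, 4))).shape == (2, 4)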
def replace_op(self, graph: Graph, node: Node):
    if node.module.inverse:
        axes = Const(graph, {'value': int64_array(range(2, node.module.num_axes - 1))}).create_node()
        dft_node = IDFT(graph, dict(name=node.name, in_ports_count=2)).create_node([node.in_node(0), axes])

        # Slice a real part
        begin_id = Const(graph, {'value': int64_array([0, 0])}).create_node()
        end_id = Const(graph, {'value': int64_array([0, 1])}).create_node()
        real = StridedSlice(graph, dict(name=node.name + '/real',
                                        begin_mask=[0, 0],
                                        end_mask=[0, 1],
                                        shrink_axis_mask=[0, 0],
                                        new_axis_mask=[0],
                                        ellipsis_mask=[1, 0])).create_node([dft_node, begin_id, end_id])
        squeeze_axis = Const(graph, {'value': -1}).create_node()
        res = Squeeze(graph, dict(name=node.name + '/squeeze')).create_node([real, squeeze_axis])
        return [res.id]
    else:
        zero = Const(graph, {'value': 0.0}).create_node()
        imag = Mul(graph, dict(name=node.name + '/imag')).create_node([node.in_node(0), zero])
        cmplx = PackOp(graph, dict(name=node.name + '/complex', axis=-1)).create_node([node.in_node(0), imag])
        axes = Const(graph, {'value': int64_array(range(2, node.module.num_axes))}).create_node()
        dft_node = DFT(graph, dict(name=node.name, in_ports_count=2)).create_node([cmplx, axes])
        return [dft_node.id]
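# Hedged illustration of the identity behind the two branches above: a real signal is
# promoted to complex by packing a zero imaginary part before the forward DFT, and the
# inverse DFT of such a spectrum has a (numerically) zero imaginary part, so only the
# real slice needs to be kept. A numpy sketch of this round trip:
import numpy as np

x = np.random.rand(1, 2, 8, 6)               # real NCHW-like input
spectrum = np.fft.fft2(x + 0j, axes=(2, 3))  # forward branch: zero-imag pack + DFT
restored = np.fft.ifft2(spectrum, axes=(2, 3))
assert np.allclose(restored.imag, 0.0)       # inverse branch may drop the imaginary part
assert np.allclose(restored.real, x)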
def replace_pattern(self, graph: Graph, match: dict):
    if match['rnn_layer']['op'] == 'LSTM':
        return

    rnn_layer = match['rnn_layer']

    # Build TensorIterator body first
    body = Graph(name=rnn_layer.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    inputs = [Op._create_data_node(body, rnn_layer.name + '/inport/' + str(inp),
                                   {'shape': rnn_layer.in_node(inp).shape.copy(),
                                    'value': rnn_layer.in_node(inp).value.copy()
                                    if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 1, 2]]  # X, h_init, WR, B

    inputs[0].shape[rnn_layer.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
    input_squeeze_dim = Const(body, dict(name=rnn_layer.name + '/input_squeeze_dim',
                                         value=rnn_layer.sequence_dim)).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], input_squeeze_dim],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    outputs = [Op._create_data_node(body, rnn_layer.name + '/outport/' + str(out),
                                    {'shape': rnn_layer.out_node(out).shape.copy()
                                     if out in rnn_layer.out_nodes() else None})
               for out in [0]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = np.delete(outputs[0].shape.copy(), rnn_layer.sequence_dim)
    output_unsqueeze_dim = Const(body, dict(name=rnn_layer.name + '/output_unsqueeze_dim',
                                            value=rnn_layer.sequence_dim)).create_node_with_data()
    output_unsqueeze = Unsqueeze(body, dict(name=rnn_layer.name + '/output_unsqueeze/', internal_layer_id=2))

    additional_attrs = dict(activations=rnn_layer.activations,
                            activation_alpha=rnn_layer.activation_alpha,
                            activation_beta=rnn_layer.activation_beta,
                            clip=rnn_layer.clip)
    if rnn_layer.op == 'GRU':
        additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

    # 3. ***Cell
    rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(body,
                                                     dict(hidden_size=rnn_layer.hidden_size,
                                                          name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
                                                          **additional_attrs,
                                                          internal_layer_id=1))

    gru_cell = rnn_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                 edge_attrs=[{}, {'internal_port_id': 1},
                                                             {'internal_port_id': 2},
                                                             {'bin': 'weights'},
                                                             {'bin': 'biases'}])

    # internal ports for outputs of cell
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

    gru_cell = output_unsqueeze.create_node_with_data([gru_cell, output_unsqueeze_dim])
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, gru_cell.id, 0, False)

    # 4. TensorIterator layer creating
    assert rnn_layer.direction in ['forward', 'reverse']
    if rnn_layer.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert rnn_layer.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    # stacked h_state
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,
        'axis': rnn_layer.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding last h_state to outputs
    if len(rnn_layer.out_nodes()) == 2:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }])

    ti_op = TensorIterator(graph, {
        'name': rnn_layer.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 4,
        'out_ports_count': len(rnn_layer.out_nodes()),

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,
                'axis': rnn_layer.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
        ],

        'output_port_map': output_port_map,
        # only for h state
        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
        ]
    })

    assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
        "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

    outs = ti_op.create_node_with_data([rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
                                       data_nodes=[rnn_layer.out_node(i)
                                                   for i in range(len(rnn_layer.out_nodes()))],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1}])

    if not isinstance(outs, list):
        outs = list([outs])

    graph.remove_node(rnn_layer.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
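# Hedged illustration: the body built above runs one recurrent cell step per iteration;
# the input squeeze / output unsqueeze pair corresponds to iterating over sequence_dim
# with rank-reduced slices and re-stacking the per-step results, e.g. for a [T, N, D]
# input with sequence_dim = 0:
import numpy as np

x = np.random.rand(7, 2, 4)                                             # [T, N, D]
steps = [np.squeeze(x[t:t + 1], axis=0) for t in range(x.shape[0])]     # squeeze: [1, N, D] -> [N, D]
restacked = np.concatenate([np.expand_dims(s, axis=0) for s in steps])  # unsqueeze and stack over T
assert np.array_equal(restacked, x)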
def replace_pattern(self, graph: Graph, match: dict):
    lstm = match['lstm']

    # Build TensorIterator body first
    body = Graph(name=lstm.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    inputs = [Op._create_data_node(body, lstm.name + '/inport/' + str(inp),
                                   {'shape': lstm.in_node(inp).shape.copy(),
                                    'value': lstm.in_node(inp).value.copy()
                                    if lstm.in_node(inp).value is not None and inp in [1, 2] else None})
              for inp in [0, 4, 5, 1, 2]]  # X, h_init, c_init, WR, B

    inputs[0].shape[lstm.sequence_dim] = 1
    input_squeeze = Squeeze(body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
    squeeze_dim_data = Const(body, {'name': lstm.name + '/input_squeeze_dim',
                                    'value': [lstm.sequence_dim]}).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data([inputs[0], squeeze_dim_data],
                                                    edge_attrs=[{'internal_port_id': 0}])

    # 2. Output unsqueeze Reshape
    outputs = [Op._create_data_node(body, lstm.name + '/outport/' + str(out),
                                    {'shape': lstm.out_node(out).shape.copy()
                                     if out in lstm.out_nodes() else lstm.in_node(4).shape.copy()})
               for out in [0, 1]]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    outputs[0].shape = shape_delete(outputs[0].shape, lstm.sequence_dim)
    output_unsqueeze = Unsqueeze(body, dict(name=lstm.name + 'output_unsqueeze', internal_layer_id=2))
    unsqueeze_dim_data = Const(body, {'name': lstm.name + '/output_unsqueeze_dim',
                                      'value': [lstm.sequence_dim]}).create_node_with_data()

    # 3. LSTMCell
    lstm_cell_op = LSTMCell(body, dict(hidden_size=lstm.hidden_size,
                                       activations=lstm.activations,
                                       activation_alpha=lstm.activation_alpha,
                                       activation_beta=lstm.activation_beta,
                                       clip=lstm.clip,
                                       input_forget=lstm.input_forget,
                                       name=lstm.name + '/LSTMCell',
                                       internal_layer_id=1))

    lstm_cell_node = lstm_cell_op.create_node_with_data(inputs, data_nodes=outputs,
                                                        edge_attrs=[{}, {'internal_port_id': 1},
                                                                    {'internal_port_id': 2},
                                                                    {'bin': 'weights'},
                                                                    {'bin': 'biases'}])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
    lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
    lstm_cell_node[0] = output_unsqueeze.create_node_with_data([lstm_cell_node[0], unsqueeze_dim_data])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, lstm_cell_node[0].id, 0, False)

    # 4. TensorIterator layer creating
    assert lstm.direction in ['forward', 'reverse']
    if lstm.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert lstm.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,
        'axis': lstm.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding h_state, c_state to outputs
    if len(lstm.out_nodes()) == 3:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }, {
            'external_port_id': 5,
            'internal_layer_id': 1,
            'internal_port_id': 5,
        }])

    ti_op = TensorIterator(graph, {
        'name': lstm.name + '/TensorIterator',
        'body': body,
        'in_ports_count': 3,
        'out_ports_count': len(lstm.out_nodes()),

        'input_port_map': [
            {
                'external_port_id': 0,
                'internal_layer_id': 0,
                'internal_port_id': 0,
                'axis': lstm.sequence_dim,
                'stride': stride,
                'start': start,
                'end': end,
                'part_size': 1,
            },
            {
                'external_port_id': 1,
                'internal_layer_id': 1,
                'internal_port_id': 1,
            },
            {
                'external_port_id': 2,
                'internal_layer_id': 1,
                'internal_port_id': 2,
            },
        ],

        'output_port_map': output_port_map,

        'back_edges': [
            {
                'from_layer': 1,
                'from_port': 4,
                'to_layer': 1,
                'to_port': 1,
            },
            {
                'from_layer': 1,
                'from_port': 5,
                'to_layer': 1,
                'to_port': 2,
            },
        ]
    })

    assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
        "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

    outs = ti_op.create_node_with_data([lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
                                       data_nodes=[lstm.out_node(i) for i in range(len(lstm.out_nodes()))],
                                       edge_attrs=[{'external_port_id': 0}, {'external_port_id': 1},
                                                   {'external_port_id': 2}])

    if not isinstance(outs, list):
        outs = list([outs])

    graph.remove_node(lstm.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def extract(cls, node: Node):
    Squeeze.update_node_stat(node, {'squeeze_dims': tf_int_list(node.pb.attr['squeeze_dims'].list)})
    return cls.enabled
def replace_pattern(graph, match: dict):
    # Here we will find all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
    # and make all checks needed for TensorIterator work
    cond_data = match['condition'].out_node(0) if not match['condition'].out_port(0).disconnected() else None
    time_data = match['condition'].out_node(1) if len(match['condition'].out_nodes()) >= 1 else None
    name = match['condition'].name

    back_edges = []
    inputs = []
    outputs = []

    if cond_data is not None:
        for node in cond_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
                back_edges.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)

    if time_data is not None:
        for node in time_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)
            else:
                # something goes wrong here
                assert False

    condition = match['condition']
    tensor_sequence_length = condition.in_node(0)

    nodes_to_remove = [n.id for n in (condition, cond_data, time_data, tensor_sequence_length) if n is not None]
    graph.remove_nodes_from(nodes_to_remove)

    body_nodes, extra_inputs = get_body(graph, inputs, outputs)

    if cond_data is not None:
        body_nodes = list(set(body_nodes) - set([cond_data]))

    inputs += extra_inputs

    assert all([node in graph.nodes() for node in body_nodes])

    inputs = [Node(graph, node) for node in inputs]
    outputs = [Node(graph, node) for node in outputs]
    back_edges = [Node(graph, node) for node in back_edges]

    external_inputs = [{'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                        'internal_data_id': node.out_node(0),
                        'axis': node.axis,
                        'start': node.start,
                        'end': node.end,
                        'stride': node.stride,
                        'part_size': node.part_size} for node in inputs]

    external_outputs = [{'external_data_id': node.out_node(0),
                         'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
                         'axis': node.axis,
                         'start': node.start,
                         'end': node.end,
                         'stride': node.stride,
                         'part_size': node.part_size} for node in outputs]

    back_edges_data = [{'from_data_id': node.in_node(1),
                        'to_data_id': node.out_node(0),
                        'init_data_id': node.in_node(0),
                        } for node in back_edges]

    body = Graph(name='body')
    body.graph = graph.graph
    body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
    body.add_edges_from([(u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
                         if u in body_nodes and v in body_nodes])

    graph.remove_nodes_from(body_nodes + [match['condition'].id] + [inp.id for inp in inputs] +
                            [out.id for out in outputs])
    internal_id_count = 0
    real_back_edges = []
    for edge in back_edges_data:
        assert edge['from_data_id'].id in body.nodes()
        assert edge['to_data_id'].id in body.nodes()
        assert edge['init_data_id'].id in body.nodes()
        edge['from_data_id'] = Node(body, edge['from_data_id'].id)
        edge['to_data_id'] = Node(body, edge['to_data_id'].id)
        edge['init_data_id'] = Node(body, edge['init_data_id'].id)
        add_opoutput(body, edge['from_data_id'].id, 0, False)

        # Assign/reuse ids for the back-edge start; it comes from from_data_id
        assert len(edge['from_data_id'].in_nodes()) == 1
        # layer id
        if not edge['from_data_id'].in_node().has_valid('internal_layer_id'):
            edge['from_data_id'].in_node()['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        edge['from_layer'] = edge['from_data_id'].in_node()['internal_layer_id']

        # port id
        if 'internal_port_id' not in edge['from_data_id'].in_edge():
            edge['from_data_id'].in_edge()['internal_port_id'] = internal_id_count
            internal_id_count += 1
        edge['from_port'] = edge['from_data_id'].in_edge()['internal_port_id']

        # Look at all consumers for a data that ends a back-edge
        # For each such consumer, there will be a separate back-edge (and input)
        current_real_back_edges = []
        for _, consumer, key, edge_attrs in body.out_edges(edge['to_data_id'].id, data=True, keys=True):
            real_edge = {}
            real_edge.update(edge)  # all real back_edges have the same back-edge start

            consumer = Node(body, consumer)

            if real_edge['to_data_id'].in_node().has_valid('internal_layer_id'):
                assert False
                real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                    real_edge['to_data_id'].in_node().internal_layer_id
            elif not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            real_edge['to_layer'] = consumer['internal_layer_id']

            assert 'internal_port_id' not in edge_attrs
            assert len(real_edge['init_data_id'].out_edges()) == 1
            assert not 'internal_port_id' in real_edge['init_data_id'].out_edge()
            edge_attrs['internal_port_id'] = internal_id_count
            internal_id_count += 1
            real_edge['to_port'] = edge_attrs['internal_port_id']
            real_edge['consumer'] = consumer
            real_edge['consumer_key'] = key
            real_edge['attrs'] = deepcopy(edge_attrs)
            current_real_back_edges.append(real_edge)

        # connect initial data node with each consumer providing actual edge attributes
        body.add_edges_from([(real_edge['init_data_id'].id,
                              real_edge['consumer'].id,
                              real_edge['consumer_key'],
                              real_edge['attrs']) for real_edge in current_real_back_edges])

        body.remove_nodes_from([edge['to_data_id'].id, edge['to_data_id'].in_node().id])
        real_back_edges += current_real_back_edges

    real_external_inputs = []
    for ext_inp in external_inputs:
        assert ext_inp['external_data_id'].id not in body.nodes()
        assert ext_inp['internal_data_id'].id in body.nodes()
        ext_inp['internal_data_id'] = Node(body, ext_inp['internal_data_id'].id)

        if ext_inp['axis'] is not None:
            # Insert squeezing resize at input port that has partitioning
            shape = ext_inp['internal_data_id'].shape.copy()
            assert not ext_inp['internal_data_id'].has_valid('value')
            new_input_data = Op._create_data_node(body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                                                  dict(shape=shape_insert(shape, ext_inp['axis'], 1)))

            reshape_op = Squeeze(body, dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze'))
            reshape_dim_data = Const(body, {'name': ext_inp['internal_data_id'].name + '/ReshapeDim',
                                            'value': ext_inp['axis']}).create_node_with_data()
            reshape_op.create_node_with_data([new_input_data, reshape_dim_data],
                                             data_nodes=[ext_inp['internal_data_id']])
            ext_inp['internal_data_id'] = new_input_data

        ext_inp['internal_data_id']['is_input'] = True
        assert len(ext_inp['internal_data_id'].in_nodes()) == 0
        ext_inp['external_port_id'] = internal_id_count
        internal_id_count += 1
        for _, consumer, edge_attrs in body.out_edges(ext_inp['internal_data_id'].id, data=True):
            real_ext_inp = {}
            real_ext_inp.update(ext_inp)
            consumer = Node(body, consumer)
            if not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in edge_attrs:
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
            real_ext_inp['internal_layer_id'] = consumer['internal_layer_id']
            real_ext_inp['internal_port_id'] = edge_attrs['internal_port_id']
            real_external_inputs.append(real_ext_inp)

    for ext_out in external_outputs:
        assert ext_out['external_data_id'].id not in body.nodes()
        assert ext_out['internal_data_id'].id in body.nodes()
        ext_out['internal_data_id'] = Node(body, ext_out['internal_data_id'].id)

        if ext_out['axis'] is not None:
            # Insert unsqueezing resize at output port that has partitioning
            reshape_op = Unsqueeze(body, dict(name=ext_out['internal_data_id'].name + '/OutputUnsqueeze'))
            reshape_dim_data = Const(body, {'name': ext_out['internal_data_id'].name + '/ReshapeDim',
                                            'value': ext_out['axis']}).create_node_with_data()
            ext_out['internal_data_id'] = reshape_op.create_node_with_data([ext_out['internal_data_id'],
                                                                            reshape_dim_data])

        # TODO: add here working with simple outputs
        if not any([out_node.soft_get('op', None) == 'Result'
                    for out_node in ext_out['internal_data_id'].out_nodes()]):
            add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

        # assert len(ext_out['internal_data_id'].out_nodes()) == 0
        assert len(ext_out['internal_data_id'].in_nodes()) == 1
        if not 'internal_layer_id' in ext_out['internal_data_id'].in_node():
            ext_out['internal_data_id'].in_node()['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
            ext_out['internal_data_id'].in_edge()['internal_port_id'] = internal_id_count
            internal_id_count += 1
        ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node()['internal_layer_id']
        ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge()['internal_port_id']
        ext_out['external_port_id'] = internal_id_count
        internal_id_count += 1

    # create TensorIterator layer with pre-computed components
    ti_op = TensorIterator(graph, {
        'name': name + '/TensorIterator',
        'body': body,
        'in_ports_count': len(external_inputs),
        'out_ports_count': len(external_outputs),

        'input_port_map': [
            {field: external_input[field] for field in
             ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis',
              'stride', 'part_size', 'start', 'end']}
            for external_input in real_external_inputs],

        'output_port_map': [
            {field: external_output[field] for field in
             ['external_port_id', 'internal_layer_id', 'internal_port_id', 'axis',
              'stride', 'part_size', 'start', 'end']}
            for external_output in external_outputs],

        'back_edges': [
            {field: edge[field] for field in ['from_layer', 'from_port', 'to_layer', 'to_port']}
            for edge in real_back_edges],
    })

    ti_outs = ti_op.create_node_with_data(
        inputs=[inp['external_data_id'] for inp in external_inputs],
        edge_attrs=[{'external_port_id': inp['external_port_id']} for inp in external_inputs],
        data_nodes=[out['external_data_id'] for out in external_outputs])

    if not isinstance(ti_outs, list):
        ti_outs = [ti_outs]

    for i, out in enumerate(ti_outs):
        out.in_edge()['external_port_id'] = external_outputs[i]['external_port_id']

    ti = ti_outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
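# Hedged illustration: a back edge in the constructed TensorIterator plays the role of
# a loop-carried variable in a plain Python loop; 'init_data_id' is the value consumed
# on the first iteration, and the from_layer/from_port -> to_layer/to_port pair routes
# each iteration's result into the next iteration:
state = 0                   # init_data_id: value for the first iteration
for step in range(5):       # iterations of the TensorIterator body
    state = state + step    # from_port (producer) feeds to_port (consumer) next time
assert state == 10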