def find_and_replace_pattern(self, graph: Graph):
    """Fuse groups of StridedSlice consumers of one data node into a single SplitV.

    For every non-constant data node whose consumers are StridedSlice ops that
    together partition exactly one dimension (stride 1 everywhere), the
    StridedSlice nodes are removed and replaced by one SplitV op. Gaps in the
    partition are filled with 'fake_data' output nodes; duplicate slices are
    deduplicated by rerouting their consumers.
    """
    # Iterate over all data nodes and find all with >= 1 consumers
    data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
    for input_data in data_nodes:
        # We don't use constant data nodes
        if input_data.value is not None:
            continue
        input_shape = np.array(input_data.shape)

        # Get all StridedSlice consumers (only those fed by this data node on port 0)
        out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice' and node.in_node(0).name == input_data.name]
        if len(out_nodes) < 1:
            continue

        valid_for_replacement = True

        # All consumers must slice over the same rank
        for node in out_nodes:
            if len(node.slices) != len(out_nodes[0].slices):
                valid_for_replacement = False

        # Detect dimension for splitting: exactly one dim may differ from the full range
        split_channel_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            l, r, stride = s.start, s.stop, s.step
            if l != 0 or r != input_shape[dim_id]:
                if split_channel_dim is None:
                    split_channel_dim = dim_id
                else:
                    valid_for_replacement = False

        # split_dims contains tuples with split range and output data node
        split_dims = []
        for out_id, node in enumerate(out_nodes):
            # Check that StridedSlice op has stride eq 1 and splits only feature channel
            for id, s in enumerate(node.slices):
                l, r, stride = s.start, s.stop, s.step
                # We don't support StridedSlice with stride != 1
                if stride != 1:
                    valid_for_replacement = False
                if id == split_channel_dim:
                    split_dims.append((s.start, s.stop, node.out_node()))

        if not valid_for_replacement:
            continue

        # Check feature split intersection
        final_data_nodes_list = []
        sorted_split_dims = sorted(split_dims, key=lambda item: (item[0], item[1]))

        # check if we have similar StridedSlice operations with different outputs:
        # reroute consumers of the duplicate slice to the first one, then drop it
        prev_sd = sorted_split_dims[0]
        to_remove = []
        for i in range(1, len(sorted_split_dims)):
            if sorted_split_dims[i][0] == prev_sd[0] and sorted_split_dims[i][1] == prev_sd[1] and sorted_split_dims[i][2].name != prev_sd[2].name:
                cur_node = sorted_split_dims[i][2]
                for out in cur_node.out_nodes():
                    # preserve original edge attributes when rerouting
                    attrs = deepcopy(graph.get_edge_data(cur_node.id, out.id)[0])
                    graph.remove_edge(cur_node.id, out.id)
                    graph.add_edge(prev_sd[2].id, out.id, **attrs)
                to_remove.append(i)

        # pop in reverse order so earlier indices stay valid
        for ind in reversed(to_remove):
            sorted_split_dims.pop(ind)

        size_splits = []
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Split dims shouldn't intersect
            if l < prev_r:
                valid_for_replacement = False
            # Save missing tensor part: emit a terminal 'fake_data' node for the gap
            if l > prev_r:
                shape = np.array(input_shape)
                size_splits.append(l - prev_r)
                shape[split_channel_dim] = l - prev_r
                data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
                add_opoutput(graph, data_node.id, 0, False)
                final_data_nodes_list.append(data_node)

            prev_r = r
            size_splits.append(r - l)
            final_data_nodes_list.append(out)

        if prev_r > input_shape[split_channel_dim]:
            valid_for_replacement = False
        elif prev_r != input_shape[split_channel_dim]:
            # Add last part of tensor
            shape = input_shape.copy()
            shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
            size_splits.append(input_shape[split_channel_dim] - prev_r)
            data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape})
            add_opoutput(graph, data_node.id, 0, False)
            final_data_nodes_list.append(data_node)

        if not valid_for_replacement:
            continue

        # Replace StridedSlice outputs that carry shrink/new axis masks with
        # reshaped outputs, keeping final_data_nodes_list consistent
        for node in out_nodes:
            if not np.all([x == 0 for x in node.shrink_axis_mask]):
                out_node = node.out_node()
                if np.any(node['shrink_axis_mask']):
                    self.add_reshape_for_shrink(graph, node)
                if np.any(node['new_axis_mask']):
                    self.add_reshape_for_new(graph, node)

                for i in range(len(final_data_nodes_list)):
                    if final_data_nodes_list[i].name == out_node.name:
                        final_data_nodes_list[i] = node.out_node()
                        break

        # Insert Split layer and remove old StridedSlice layers
        # 1. Remove connections from input_data to StridedSlice ops
        out_data_nodes = []
        name_for_future_split = out_nodes[0].name
        for node in out_nodes:
            out_data_nodes.append(node.out_node())
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 2. Create Split layer and reorder outputs
        split = SplitV(graph, dict(name=name_for_future_split + "/Split", axis=split_channel_dim, size_splits=size_splits, out_ports_count=len(size_splits)))
        split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)
def replace_pattern(self, graph: Graph, match: dict):
    """Convert a non-LSTM RNN/GRU sequence layer into a TensorIterator.

    Builds a body sub-graph (squeeze -> *Cell -> unsqueeze), wires the
    internal/external port maps and the hidden-state back edge, then removes
    the original sequence node. LSTM is handled by a dedicated pass and is
    skipped here.
    """
    if match['rnn_layer']['op'] == 'LSTM':
        return

    rnn_layer = match['rnn_layer']

    # Build TensorIterator body first
    body = Graph(name=rnn_layer.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    # Body input data nodes; weights/biases (ports 1, 2) keep their values.
    inputs = [
        Op._create_data_node(
            body, rnn_layer.name + '/inport/' + str(inp), {
                'shape': rnn_layer.in_node(inp).shape.copy(),
                'value': rnn_layer.in_node(inp).value.copy()
                if rnn_layer.in_node(inp).value is not None and inp in [1, 2] else None
            }) for inp in [0, 4, 1, 2]
    ]  # X, h_init, WR, B

    # Inside the body one time step is processed, so sequence length is 1
    inputs[0].shape[rnn_layer.sequence_dim] = 1
    input_squeeze = Squeeze(
        body,
        dict(name=rnn_layer.name + '/input_squeeze', internal_layer_id=0))
    input_squeeze_dim = Const(
        body,
        dict(name=rnn_layer.name + '/input_squeeze_dim',
             value=rnn_layer.sequence_dim)).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data(
        [inputs[0], input_squeeze_dim],
        edge_attrs=[{
            'internal_port_id': 0
        }])

    # 2. Output unsqueeze Reshape
    outputs = [
        Op._create_data_node(
            body, rnn_layer.name + '/outport/' + str(out), {
                'shape': rnn_layer.out_node(out).shape.copy()
                if out in rnn_layer.out_nodes() else None
            }) for out in [0]
    ]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    # Cell output has no sequence dimension; it is re-added by the Unsqueeze
    outputs[0].shape = shape_delete(outputs[0].shape, rnn_layer.sequence_dim)
    output_unsqueeze_dim = Const(
        body,
        dict(name=rnn_layer.name + '/output_unsqueeze_dim',
             value=rnn_layer.sequence_dim)).create_node_with_data()
    output_unsqueeze = Unsqueeze(
        body,
        dict(name=rnn_layer.name + '/output_unsqueeze/', internal_layer_id=2))

    additional_attrs = dict(activations=rnn_layer.activations,
                            activation_alpha=rnn_layer.activation_alpha,
                            activation_beta=rnn_layer.activation_beta,
                            clip=rnn_layer.clip)
    if rnn_layer.op == 'GRU':
        # GRU-only attribute
        additional_attrs['linear_before_reset'] = rnn_layer.linear_before_reset

    # 3. Cell op (RNNCell or GRUCell depending on the source layer)
    rnn_cell_op = self.get_rnn_cell(rnn_layer['op'])(
        body,
        dict(hidden_size=rnn_layer.hidden_size,
             name=rnn_layer.name + '/{}Cell'.format(rnn_layer.op),
             **additional_attrs,
             internal_layer_id=1))

    gru_cell = rnn_cell_op.create_node_with_data(inputs,
                                                 data_nodes=outputs,
                                                 edge_attrs=[{}, {
                                                     'internal_port_id': 1
                                                 }, {
                                                     'internal_port_id': 2
                                                 }, {
                                                     'bin': 'weights'
                                                 }, {
                                                     'bin': 'biases'
                                                 }])

    # internal ports for outputs of cell
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 4  # h_state

    gru_cell = output_unsqueeze.create_node_with_data(
        [gru_cell, output_unsqueeze_dim])
    gru_cell.in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, gru_cell.id, 0, False)

    # 4. TensorIterator layer creating
    assert rnn_layer.direction in ['forward', 'reverse']
    if rnn_layer.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert rnn_layer.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    # stacked h_state
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,
        'axis': rnn_layer.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding last h_state to outputs
    if len(rnn_layer.out_nodes()) == 2:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }])

    ti_op = TensorIterator(
        graph, {
            'name': rnn_layer.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 4,
            'out_ports_count': len(rnn_layer.out_nodes()),
            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,
                    'axis': rnn_layer.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
            ],
            'output_port_map': output_port_map,
            # only for h state
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
            ]
        })

    assert sorted(rnn_layer.out_nodes().keys()) == list(range(len(rnn_layer.out_nodes()))), \
        "There are gaps in output ports of GRUSequence operation. Node {}".format(rnn_layer.id)

    outs = ti_op.create_node_with_data(
        [rnn_layer.in_node(i) for i in [0, 4]],  # X, h_init
        data_nodes=[
            rnn_layer.out_node(i) for i in range(len(rnn_layer.out_nodes()))
        ],
        edge_attrs=[{
            'external_port_id': 0
        }, {
            'external_port_id': 1
        }])

    if not isinstance(outs, list):
        outs = list([outs])

    # Detach the original sequence node and assign external output port ids
    graph.remove_node(rnn_layer.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def split_bidirectional(self, bidirectional_cell: Node, new_init_hiddens: list,
                        new_init_cells: list, splitted_W: tuple,
                        splitted_R: tuple, splitted_B: tuple):
    """
    Split one bidirectional RNNSequence node into 2 one-directional RNNSequence nodes.

    All input data nodes should be already prepared; they have 2 in the
    num_dir dimension. Returns a list with the output data nodes of the two
    new cells (index 0: forward, index 1: reverse).
    """
    all_outputs = []
    for i in [0, 1]:
        direction = ['forward', 'reverse'][i]
        op = self.get_new_cell(bidirectional_cell, direction)

        # Per-direction sequence output: num_directions axis (dim 1) shrinks 2 -> 1
        output_data = Op._create_data_node(
            bidirectional_cell.graph,
            name=bidirectional_cell.out_node(0).name + '/Split/' + str(i),
            attrs={'shape': bidirectional_cell.out_node(0).shape.copy()}
        )

        assert output_data.shape[1] == 2
        output_data.shape[1] = 1

        # Per-direction hidden state: num_directions axis (dim 0) shrinks 2 -> 1
        output_hidden = Op._create_data_node(
            bidirectional_cell.graph,
            name=bidirectional_cell.out_node(1).name + '/Split/' + str(i),
            attrs={'shape': bidirectional_cell.out_node(1).shape.copy()}
        )

        assert output_hidden.shape[0] == 2
        output_hidden.shape[0] = 1

        data_nodes = [
            output_data,
            output_hidden,
        ]

        if bidirectional_cell.op == 'LSTM':
            # LSTM additionally carries a cell state output
            output_cell = Op._create_data_node(
                bidirectional_cell.graph,
                name=bidirectional_cell.out_node(2).name + '/Split/' + str(i),
                attrs={'shape': bidirectional_cell.out_node(2).shape.copy()}
            )

            assert output_cell.shape[0] == 2
            output_cell.shape[0] = 1

            data_nodes.append(output_cell)

        all_outputs.append(
            op.create_node_with_data(
                inputs=[
                    bidirectional_cell.in_node(0),  # X input is shared by both directions
                    splitted_W[i],
                    splitted_R[i],
                    splitted_B[i],
                    None,
                    new_init_hiddens[i],
                    new_init_cells[i] if bidirectional_cell.op == 'LSTM' else None,
                ],
                data_nodes=data_nodes
            )
        )
    return all_outputs
def replace_pattern(self, graph: Graph, match: dict):
    """Convert an LSTM sequence layer into a TensorIterator.

    Builds a body sub-graph (squeeze -> LSTMCell -> unsqueeze), wires the
    internal/external port maps and two back edges (hidden and cell state),
    then removes the original LSTM sequence node.
    """
    lstm = match['lstm']

    # Build TensorIterator body first
    body = Graph(name=lstm.name + '/sub_graph')
    body.graph = graph.graph

    # 1. Input squeeze Reshape
    # Body input data nodes; weights/biases (ports 1, 2) keep their values.
    inputs = [
        Op._create_data_node(
            body, lstm.name + '/inport/' + str(inp), {
                'shape': lstm.in_node(inp).shape.copy(),
                'value': lstm.in_node(inp).value.copy()
                if lstm.in_node(inp).value is not None and inp in [1, 2] else None
            }) for inp in [0, 4, 5, 1, 2]
    ]  # X, WR, B, h_init, c_init

    # Inside the body one time step is processed, so sequence length is 1
    inputs[0].shape[lstm.sequence_dim] = 1
    input_squeeze = Squeeze(
        body, dict(name=lstm.name + '/input_squeeze', internal_layer_id=0))
    squeeze_dim_data = Const(body, {
        'name': lstm.name + '/input_squeeze_dim',
        'value': [lstm.sequence_dim]
    }).create_node_with_data()
    inputs[0] = input_squeeze.create_node_with_data(
        [inputs[0], squeeze_dim_data],
        edge_attrs=[{
            'internal_port_id': 0
        }])

    # 2. Output unsqueeze Reshape
    # If port 1 (h_state) has no consumer, fall back to the h_init shape.
    outputs = [
        Op._create_data_node(
            body, lstm.name + '/outport/' + str(out), {
                'shape': lstm.out_node(out).shape.copy()
                if out in lstm.out_nodes() else lstm.in_node(4).shape.copy()
            }) for out in [0, 1]
    ]
    for out in outputs:
        add_opoutput(body, out.id, 0, False)

    # Cell output has no sequence dimension; it is re-added by the Unsqueeze
    outputs[0].shape = np.delete(outputs[0].shape, lstm.sequence_dim)
    # NOTE(review): node name lacks a '/' separator ('...output_unsqueeze'),
    # unlike the sibling GRU/RNN pass — looks unintended, confirm before changing.
    output_unsqueeze = Unsqueeze(
        body, dict(name=lstm.name + 'output_unsqueeze', internal_layer_id=2))
    unsqueeze_dim_data = Const(
        body, {
            'name': lstm.name + '/output_unsqueeze_dim',
            'value': [lstm.sequence_dim]
        }).create_node_with_data()

    # 3. LSTMCell
    lstm_cell_op = LSTMCell(
        body,
        dict(hidden_size=lstm.hidden_size,
             activations=lstm.activations,
             activation_alpha=lstm.activation_alpha,
             activation_beta=lstm.activation_beta,
             clip=lstm.clip,
             input_forget=lstm.input_forget,
             name=lstm.name + '/LSTMCell',
             internal_layer_id=1))

    lstm_cell_node = lstm_cell_op.create_node_with_data(
        inputs,
        data_nodes=outputs,
        edge_attrs=[{}, {
            'internal_port_id': 1
        }, {
            'internal_port_id': 2
        }, {
            'bin': 'weights'
        }, {
            'bin': 'biases'
        }])
    # internal ports for cell state outputs: 4 = h_state, 5 = c_state
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 4
    lstm_cell_node[0].in_node().out_edge(1)['internal_port_id'] = 5
    lstm_cell_node[0] = output_unsqueeze.create_node_with_data(
        [lstm_cell_node[0], unsqueeze_dim_data])
    lstm_cell_node[0].in_node().out_edge(0)['internal_port_id'] = 3
    add_opoutput(body, lstm_cell_node[0].id, 0, False)

    # 4. TensorIterator layer creating
    assert lstm.direction in ['forward', 'reverse']
    if lstm.direction == 'forward':
        stride = 1
        start = None
        end = None
    else:
        assert lstm.direction == 'reverse'
        stride = -1
        start = -1
        end = 0

    # stacked sequence output
    output_port_map = [{
        'external_port_id': 3,
        'internal_layer_id': 2,
        'internal_port_id': 3,
        'axis': lstm.sequence_dim,
        'stride': stride,
        'start': start,
        'end': end,
        'part_size': 1,
    }]

    # Adding h_state, c_state to outputs
    if len(lstm.out_nodes()) == 3:
        output_port_map.extend([{
            'external_port_id': 4,
            'internal_layer_id': 1,
            'internal_port_id': 4,
        }, {
            'external_port_id': 5,
            'internal_layer_id': 1,
            'internal_port_id': 5,
        }])

    ti_op = TensorIterator(
        graph, {
            'name': lstm.name + '/TensorIterator',
            'body': body,
            'in_ports_count': 3,
            'out_ports_count': len(lstm.out_nodes()),
            'input_port_map': [
                {
                    'external_port_id': 0,
                    'internal_layer_id': 0,
                    'internal_port_id': 0,
                    'axis': lstm.sequence_dim,
                    'stride': stride,
                    'start': start,
                    'end': end,
                    'part_size': 1,
                },
                {
                    'external_port_id': 1,
                    'internal_layer_id': 1,
                    'internal_port_id': 1,
                },
                {
                    'external_port_id': 2,
                    'internal_layer_id': 1,
                    'internal_port_id': 2,
                },
            ],
            'output_port_map': output_port_map,
            # hidden and cell state recurrence
            'back_edges': [
                {
                    'from_layer': 1,
                    'from_port': 4,
                    'to_layer': 1,
                    'to_port': 1,
                },
                {
                    'from_layer': 1,
                    'from_port': 5,
                    'to_layer': 1,
                    'to_port': 2,
                },
            ]
        })

    assert sorted(lstm.out_nodes().keys()) == list(range(len(lstm.out_nodes()))), \
        "There are gaps in output ports of LSTMSequence operation. Node {}".format(lstm.id)

    outs = ti_op.create_node_with_data(
        [lstm.in_node(i) for i in [0, 4, 5]],  # X, h_init, c_init
        data_nodes=[
            lstm.out_node(i) for i in range(len(lstm.out_nodes()))
        ],
        edge_attrs=[{
            'external_port_id': 0
        }, {
            'external_port_id': 1
        }, {
            'external_port_id': 2
        }])

    if not isinstance(outs, list):
        outs = list([outs])

    # Detach the original LSTM node and assign external output port ids
    graph.remove_node(lstm.id)
    outs[0].in_edge(0)['external_port_id'] = 3
    for i, out in enumerate(outs[1:]):
        external_port_id = 4 + i
        out.in_edge()['external_port_id'] = external_port_id

    ti = outs[0].in_node()
    TensorIterator.cover_body_input_data_nodes_with_parameter_ops(ti)
    TensorIterator.cover_body_constant_data_nodes_with_const_ops(ti)
    TensorIterator.normalize_internal_ids(ti)
def split_multilayer_cell(self, graph: Graph, match: dict):
    """
    Split one multilayer type=RNNSequence cell to num_layers consecutive cells.

    All parameters splits to parts for new num_layers cells: the packed
    parameter blob is sliced into per-layer weight and bias segments, optional
    hidden/cell initial states are sliced per layer, and each new cell is
    chained on the previous cell's sequence output. Returns the collected
    hidden (and, for LSTM, cell) state outputs per layer.
    """
    input = match['input']
    rnn_layer = match['rnn_layer']
    params = match['params'].value.copy()

    have_hidden = False
    if 2 in rnn_layer.in_nodes():
        hidden_state_value = rnn_layer.in_node(2).value
        have_hidden = True

    have_cell = False
    if 3 in rnn_layer.in_nodes():
        cell_state_value = rnn_layer.in_node(3).value
        have_cell = True

    direction = 2 if rnn_layer.has_num_directions else 1
    num_layers = rnn_layer.num_layers
    input_size = input.shape[2]
    # Total size of all bias segments across all layers/directions
    bsize = (2 * rnn_layer.hidden_size * direction * num_layers) * rnn_layer.multiplier

    size = rnn_layer.hidden_size * direction * rnn_layer.multiplier
    # Layer 0 consumes input_size features; deeper layers consume the previous
    # layer's output (hidden_size * direction). '+ 2' accounts for the two bias vectors.
    first_layer_params_size = (input_size + rnn_layer.hidden_size + 2) * size
    other_layer_params_size = (rnn_layer.hidden_size * direction + rnn_layer.hidden_size + 2) * size
    assert params.size == (first_layer_params_size + (num_layers - 1) * other_layer_params_size)

    input_node = input
    params_layer_size_count = 0
    output_states = [[], []]  # [0]: hidden states, [1]: cell states (LSTM only)

    # Weights come first in the packed blob, biases last
    param_w = params[0:len(params) - bsize]
    param_b = params[len(params) - bsize:]
    layer_bsize = (2 * rnn_layer.hidden_size * direction) * rnn_layer.multiplier

    for l in range(num_layers):
        params_layer_size = first_layer_params_size if l == 0 else other_layer_params_size

        # Slice this layer's weights and biases and re-pack them together
        layer_params_w = param_w[params_layer_size_count:params_layer_size_count +
                                 (params_layer_size - layer_bsize)].copy()
        layer_params_b = param_b[layer_bsize * l:layer_bsize * l + layer_bsize].copy()
        layer_params = np.concatenate((layer_params_w, layer_params_b), axis=0)
        params_layer_size_count = params_layer_size_count + params_layer_size - layer_bsize

        op = self.get_new_cell(rnn_layer, l)
        name = str(rnn_layer.soft_get('name', rnn_layer.id))
        params_value_node = Const(
            rnn_layer.graph,
            dict(name=name + '/LayerSplittedParamsLSTM/{}/'.format(l),
                 value=layer_params)).create_node_with_data()

        if have_hidden:
            # Per-layer slice along the (num_layers * num_directions) axis
            layer_hidden_state = hidden_state_value[l * direction:l * direction + direction]
            hidden_state_value_node = Const(
                rnn_layer.graph,
                dict(name=name + '/LayerSplittedHiddenState/{}/'.format(l),
                     value=layer_hidden_state)).create_node_with_data()
        else:
            hidden_state_value_node = None

        if have_cell:
            layer_cell_state = cell_state_value[l * direction:l * direction + direction]
            cell_state_value_node = Const(
                rnn_layer.graph,
                dict(name=name + '/LayerSplittedCellState/{}/'.format(l),
                     value=layer_cell_state)).create_node_with_data()
        else:
            cell_state_value_node = None

        if l < num_layers - 1:
            # Intermediate layer: fresh data node feeding the next layer
            output_data = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(0).name + '/LayerSplit/' + str(l),
                attrs={'shape': rnn_layer.out_node(0).shape.copy()})
        else:
            # Last layer reuses the original output data node
            output_data = rnn_layer.out_node(0)

        # Output nodes creating:
        state_size = np.array(
            [input.shape[rnn_layer.batch_dim], rnn_layer.hidden_size],
            dtype=np.int64)
        if rnn_layer.has_num_directions:
            state_size = np.insert(state_size, 0, direction)

        output_hidden = Op._create_data_node(
            rnn_layer.graph,
            name=rnn_layer.out_node(1).name + '/LayerSplit/' + str(l),
            attrs={'shape': np.array(state_size)})

        current_data_nodes = [output_data, output_hidden]

        if rnn_layer.op == 'LSTM':
            output_cell = Op._create_data_node(
                rnn_layer.graph,
                name=rnn_layer.out_node(2).name + '/LayerSplit/' + str(l),
                attrs={'shape': np.array(state_size)})
            current_data_nodes.append(output_cell)

        data_nodes = op.create_node_with_data(
            inputs=[
                input_node, params_value_node, hidden_state_value_node,
                cell_state_value_node
            ],
            data_nodes=current_data_nodes,
        )

        # Chain: this layer's sequence output is the next layer's input
        input_node = data_nodes[0]
        output_states[0].append(data_nodes[1])
        if rnn_layer.op == 'LSTM':
            output_states[1].append(data_nodes[2])
    return output_states
def replace_pattern(graph, match: dict):
    """Assemble a TensorIterator from matched condition/input/output/back-edge nodes.

    Finds all parts of the TI (condition, TensorIteratorInput/Output ops, back
    edges, body nodes), moves the body into a separate Graph, assigns
    internal layer/port ids from a single running counter, and replaces the
    whole construct with one TensorIterator op.
    """
    # Here we will found all parts of TI: condition, inputs/outputs, back edges, body and create TensorIterator Op
    # and make all checks needed for TensorIteator work
    cond_data = match['condition'].out_node(0)
    time_data = match['condition'].out_node(1) if len(
        match['condition'].out_nodes()) > 1 else None
    name = match['condition'].name

    assert match['condition'].in_node(0).has_valid('value')

    back_edges = []
    inputs = []
    outputs = []

    # Classify consumers of the condition output
    for node in cond_data.out_nodes():
        if node['kind'] == 'op' and node['op'] == 'TensorIteratorBackEdge':
            back_edges.append(node.id)
        elif node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
            inputs.append(node.id)
        elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
            outputs.append(node.id)

    if time_data is not None:
        for node in time_data.out_nodes():
            if node['kind'] == 'op' and node['op'] == 'TensorIteratorInput':
                inputs.append(node.id)
            elif node['kind'] == 'op' and node['op'] == 'TensorIteratorOutput':
                outputs.append(node.id)
            else:
                # something goes wrong here
                assert False

    condition = match['condition']
    tensor_sequence_length = condition.in_node(0)

    # The condition construct itself is not needed anymore
    graph.remove_nodes_from(
        [condition.id, cond_data.id, tensor_sequence_length.id])
    if time_data is not None:
        graph.remove_nodes_from([time_data.id])

    body_nodes, extra_inputs = get_body(graph, inputs, outputs)
    body_nodes = list(set(body_nodes) - set([cond_data]))

    inputs += extra_inputs

    assert all([node in graph.nodes() for node in body_nodes])

    inputs = [Node(graph, node) for node in inputs]
    outputs = [Node(graph, node) for node in outputs]
    back_edges = [Node(graph, node) for node in back_edges]

    # Collect iteration attributes per external input/output; nodes with an
    # 'axis' take their external data from port 1, plain ones from port 0
    external_inputs = [{
        'external_data_id': node.in_node(1 if node.has_valid('axis') else 0),
        'internal_data_id': node.out_node(0),
        'axis': node.axis,
        'start': node.start,
        'end': node.end,
        'stride': node.stride,
        'part_size': node.part_size
    } for node in inputs]

    external_outputs = [{
        'external_data_id': node.out_node(0),
        'internal_data_id': node.in_node(1 if node.has_valid('axis') else 0),
        'axis': node.axis,
        'start': node.start,
        'end': node.end,
        'stride': node.stride,
        'part_size': node.part_size
    } for node in outputs]

    back_edges_data = [{
        'from_data_id': node.in_node(1),
        'to_data_id': node.out_node(0),
        'init_data_id': node.in_node(0),
    } for node in back_edges]

    # Move body nodes/edges into a dedicated sub-graph
    body = Graph(name='body')
    body.graph = graph.graph
    body.add_nodes_from([(node, graph.node[node]) for node in body_nodes])
    body.add_edges_from([
        (u, v, k, d) for u, v, k, d in graph.edges(data=True, keys=True)
        if u in body_nodes and v in body_nodes
    ])

    graph.remove_nodes_from(body_nodes + [match['condition'].id] +
                            [inp.id for inp in inputs] +
                            [out.id for out in outputs])

    # Single running counter used for all internal layer ids, internal port
    # ids and external port ids — ids are unique across all three.
    internal_id_count = 0
    real_back_edges = []
    for edge in back_edges_data:
        assert edge['from_data_id'].id in body.nodes()
        assert edge['to_data_id'].id in body.nodes()
        assert edge['init_data_id'].id in body.nodes()
        edge['from_data_id'] = Node(body, edge['from_data_id'].id)
        edge['to_data_id'] = Node(body, edge['to_data_id'].id)
        edge['init_data_id'] = Node(body, edge['init_data_id'].id)
        add_opoutput(body, edge['from_data_id'].id, 0, False)

        # Assign/reuse ids for the back-edge start; it comes from from_data_id
        assert len(edge['from_data_id'].in_nodes()) == 1
        # layer id
        if not edge['from_data_id'].in_node().has_valid(
                'internal_layer_id'):
            edge['from_data_id'].in_node(
            )['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        edge['from_layer'] = edge['from_data_id'].in_node(
        )['internal_layer_id']

        # port id
        if 'internal_port_id' not in edge['from_data_id'].in_edge():
            edge['from_data_id'].in_edge(
            )['internal_port_id'] = internal_id_count
            internal_id_count += 1
        edge['from_port'] = edge['from_data_id'].in_edge(
        )['internal_port_id']

        # Look at all consumers for a data that ends a back-edge
        # For each such consumer, there will be a separate back-edge (and input)
        current_real_back_edges = []
        for _, consumer, key, edge_attrs in body.out_edges(
                edge['to_data_id'].id, data=True, keys=True):
            real_edge = {}
            real_edge.update(
                edge)  # all real back_edges have the same back-edge start

            consumer = Node(body, consumer)

            # NOTE(review): 'assert False' makes the first branch dead code —
            # the producer of to_data_id is never expected to have an id here.
            if real_edge['to_data_id'].in_node().has_valid(
                    'internal_layer_id'):
                assert False
                real_edge['to_data_id'].out_node()['internal_layer_id'] = \
                    real_edge['to_data_id'].in_node().internal_layer_id
            elif not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            real_edge['to_layer'] = consumer['internal_layer_id']

            assert 'internal_port_id' not in edge_attrs
            assert len(real_edge['init_data_id'].out_edges()) == 1
            assert not 'internal_port_id' in real_edge[
                'init_data_id'].out_edge()
            edge_attrs['internal_port_id'] = internal_id_count
            internal_id_count += 1
            real_edge['to_port'] = edge_attrs['internal_port_id']
            real_edge['consumer'] = consumer
            real_edge['consumer_key'] = key
            real_edge['attrs'] = deepcopy(edge_attrs)
            current_real_back_edges.append(real_edge)

        # connect initial data node with each consumer providing actual edge attributes
        body.add_edges_from([
            (real_edge['init_data_id'].id, real_edge['consumer'].id,
             real_edge['consumer_key'], real_edge['attrs'])
            for real_edge in current_real_back_edges
        ])

        body.remove_nodes_from(
            [edge['to_data_id'].id, edge['to_data_id'].in_node().id])
        real_back_edges += current_real_back_edges

    real_external_inputs = []
    for ext_inp in external_inputs:
        assert ext_inp['external_data_id'].id not in body.nodes()
        assert ext_inp['internal_data_id'].id in body.nodes()
        ext_inp['internal_data_id'] = Node(body,
                                           ext_inp['internal_data_id'].id)

        if ext_inp['axis'] is not None:
            # Insert squeezing resize at input port that has partitioning
            shape = ext_inp['internal_data_id'].shape.copy()
            assert not ext_inp['internal_data_id'].has_valid('value')
            new_input_data = Op._create_data_node(
                body, ext_inp['internal_data_id'].name + '/UnsqueezedInput',
                dict(shape=np.insert(shape, ext_inp['axis'], 1)))

            dim = shape.copy()
            # try to do it dynamically reshapable along one of the axis
            # it is practically useful to reshape along batch dimension, but here we cannot detect where it is
            # so, we are guessing based on other transformations that it is the major dimension
            dim[0] = -1
            reshape_op = Reshape(
                body,
                dict(name=ext_inp['internal_data_id'].name + '/InputSqueeze',
                     dim=dim))
            reshape_op.create_node_with_data(
                [new_input_data], data_nodes=[ext_inp['internal_data_id']])
            ext_inp['internal_data_id'] = new_input_data

        ext_inp['internal_data_id']['is_input'] = True
        assert len(ext_inp['internal_data_id'].in_nodes()) == 0
        ext_inp['external_port_id'] = internal_id_count
        internal_id_count += 1
        # One port-map entry per consumer of the internal input data node
        for _, consumer, edge_attrs in body.out_edges(
                ext_inp['internal_data_id'].id, data=True):
            real_ext_inp = {}
            real_ext_inp.update(ext_inp)
            consumer = Node(body, consumer)
            if not consumer.has_valid('internal_layer_id'):
                consumer['internal_layer_id'] = internal_id_count
                internal_id_count += 1
            if not 'internal_port_id' in edge_attrs:
                edge_attrs['internal_port_id'] = internal_id_count
                internal_id_count += 1
            real_ext_inp['internal_layer_id'] = consumer[
                'internal_layer_id']
            real_ext_inp['internal_port_id'] = edge_attrs[
                'internal_port_id']
            real_external_inputs.append(real_ext_inp)

    for ext_out in external_outputs:
        assert ext_out['external_data_id'].id not in body.nodes()
        assert ext_out['internal_data_id'].id in body.nodes()
        ext_out['internal_data_id'] = Node(body,
                                           ext_out['internal_data_id'].id)

        if ext_out['axis'] is not None:
            # Insert unsqueezing resize at output port that has partitioning
            dim = ext_out['internal_data_id'].shape.copy()
            # trying to make it dynamically reshapable (see related comment above for the first Reshape)
            dim[0] = -1
            assert not ext_out['internal_data_id'].has_valid('value')
            reshape_op = Reshape(
                body,
                dict(name=ext_out['internal_data_id'].name +
                     '/OutputUnsqueeze',
                     dim=np.insert(dim, ext_out['axis'], 1)))
            ext_out['internal_data_id'] = reshape_op.create_node_with_data(
                [ext_out['internal_data_id']])

        # TODO: add here working with simple outputs
        add_opoutput(body, ext_out['internal_data_id'].id, 0, False)

        # assert len(ext_out['internal_data_id'].out_nodes()) == 0
        assert len(ext_out['internal_data_id'].in_nodes()) == 1
        if not 'internal_layer_id' in ext_out['internal_data_id'].in_node(
        ):
            ext_out['internal_data_id'].in_node(
            )['internal_layer_id'] = internal_id_count
            internal_id_count += 1
        if not 'internal_port_id' in ext_out['internal_data_id'].in_edge():
            ext_out['internal_data_id'].in_edge(
            )['internal_port_id'] = internal_id_count
            internal_id_count += 1
        ext_out['internal_layer_id'] = ext_out['internal_data_id'].in_node(
        )['internal_layer_id']
        ext_out['internal_port_id'] = ext_out['internal_data_id'].in_edge(
        )['internal_port_id']
        ext_out['external_port_id'] = internal_id_count
        internal_id_count += 1

    ti_op = TensorIterator(
        graph, {
            'name': name + '/TensorIterator',
            'body': body,
            'in_ports_count': len(external_inputs),
            'out_ports_count': len(external_outputs),
            'input_port_map': [{
                field: external_input[field]
                for field in [
                    'external_port_id', 'internal_layer_id',
                    'internal_port_id', 'axis', 'stride', 'part_size',
                    'start', 'end'
                ]
            } for external_input in real_external_inputs],
            'output_port_map': [{
                field: external_output[field]
                for field in [
                    'external_port_id', 'internal_layer_id',
                    'internal_port_id', 'axis', 'stride', 'part_size',
                    'start', 'end'
                ]
            } for external_output in external_outputs],
            'back_edges': [{
                field: edge[field]
                for field in
                ['from_layer', 'from_port', 'to_layer', 'to_port']
            } for edge in real_back_edges],
        })

    ti_outs = ti_op.create_node_with_data(
        inputs=[inp['external_data_id'] for inp in external_inputs],
        edge_attrs=[{
            'external_port_id': inp['external_port_id']
        } for inp in external_inputs],
        data_nodes=[out['external_data_id'] for out in external_outputs])

    if not isinstance(ti_outs, list):
        ti_outs = [ti_outs]

    for i, out in enumerate(ti_outs):
        out.in_edge(
        )['external_port_id'] = external_outputs[i]['external_port_id']
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
    """Fuse groups of StridedSlice consumers of one data node into a single SplitV.

    Legacy (nx.MultiDiGraph) variant: rejects any StridedSlice with a
    shrink_axis_mask set, requires stride 1 and a single split dimension,
    fills partition gaps with 'fake_data' output nodes, then replaces all
    StridedSlice nodes with one SplitV.
    """
    # Iterate over all data nodes and find all with >= 1 consumers
    data_nodes = [Node(graph, node) for node in graph.node if Node(graph, node).kind == 'data']
    for input_data in data_nodes:
        # We don't use constant data nodes
        if input_data.value is not None:
            continue
        input_shape = np.array(input_data.shape)

        # Get all StridedSlice consumers
        out_nodes = [node for node in input_data.out_nodes() if node.op == 'StridedSlice']
        if len(out_nodes) < 1:
            continue

        valid_for_replacement = True

        # Detect dimension for splitting: exactly one dim may differ from the full range
        split_channel_dim = None
        for dim_id, s in enumerate(out_nodes[0].slices):
            l, r, stride = s.start, s.stop, s.step
            if l != 0 or r != input_shape[dim_id]:
                if split_channel_dim is None:
                    split_channel_dim = dim_id
                else:
                    valid_for_replacement = False

        # split_dims contains tuples with split range and output data node
        split_dims = []
        for out_id, node in enumerate(out_nodes):
            # Check that StridedSlice op has no shrink_axis_mask attribute
            if not np.all([x == False for x in node.shrink_axis_mask]):
                valid_for_replacement = False
            # Check that StridedSlice op has stride eq 1 and splits only feature channel
            for id, s in enumerate(node.slices):
                l, r, stride = s.start, s.stop, s.step
                # We don't support StridedSlice with stride != 1
                if stride != 1:
                    valid_for_replacement = False
                if id == split_channel_dim:
                    split_dims.append((s.start, s.stop, node.out_node()))

        if not valid_for_replacement:
            continue

        # Check feature split intersection
        final_data_nodes_list = []
        sorted_split_dims = sorted(split_dims)

        size_splits = []
        prev_r = 0
        for l, r, out in sorted_split_dims:
            # Split dims shouldn't intersect
            if l < prev_r:
                valid_for_replacement = False
            # Save missing tensor part: emit a terminal 'fake_data' node for the gap
            if l > prev_r:
                shape = np.array(input_shape)
                size_splits.append(l - prev_r)
                shape[split_channel_dim] = l - prev_r
                data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
                final_data_nodes_list.append(data_node)

            prev_r = r
            size_splits.append(r - l)
            final_data_nodes_list.append(out)

        if prev_r > input_shape[split_channel_dim]:
            valid_for_replacement = False
        elif prev_r != input_shape[split_channel_dim]:
            # Add last part of tensor
            shape = input_shape.copy()
            shape[split_channel_dim] = input_shape[split_channel_dim] - prev_r
            size_splits.append(input_shape[split_channel_dim] - prev_r)
            data_node = Op._create_data_node(graph, 'fake_data', {'shape': shape, 'is_output': True})
            final_data_nodes_list.append(data_node)

        if not valid_for_replacement:
            continue

        # Insert Split layer and remove old StridedSlice layers
        # 1. Remove connections from input_data to StridedSlice ops
        out_data_nodes = []
        name_for_future_split = out_nodes[0].name
        for node in out_nodes:
            out_data_nodes.append(node.out_node())
            graph.remove_edge(input_data.id, node.id)
            graph.remove_edge(node.id, node.out_node().id)
            graph.remove_node(node.id)
            log.debug("Removed: {}".format(node.id))

        # 2. Create Split layer and reorder outputs
        split = SplitV(graph, dict(name=name_for_future_split + "/Split",
                                   axis=split_channel_dim, size_splits=size_splits))
        split.create_node_with_data(inputs=[input_data], data_nodes=final_data_nodes_list)
def rnn_infer(node: Node, out_ports=None):
    """
    General infer function for RNN, GRU, LSTM layers.

    Assume that 0-port input of node is input data for recurrent layer and node have attrs:
    hidden_size, batch_dim, sequence_dim, direction, blobs_wrb, has_num_directions,
    format, normalized, multilayers (read below).

    :param node: recurrent operation node; its input/output data node shapes are
                 mutated in place.
    :param out_ports: optional list of extra output port indices for hidden/cell
                      states; missing output data nodes are created for them.
    """
    if out_ports is None:
        out_ports = []

    # 1. Necessary checks (from ONNX specification)
    assert node.batch_dim <= 1
    assert node.sequence_dim <= 1
    assert node.batch_dim != node.sequence_dim
    assert node.direction in ['forward', 'reverse', 'bidirectional']

    # Mark weight/bias inputs as binary blobs; blobs_wrb selects the named W/R/B layout
    if node.blobs_wrb:
        mark_input_bins(node, ['W', 'R', 'B'])
    else:
        mark_input_bins(node)

    # 2. Output shape calculations
    input_shape = node.in_node(0).shape
    assert len(input_shape) == 3

    # Reshape input nodes: broadcast initial-state inputs (ports 2, 3) along axes
    # recorded in 'zero_shapes' so they match the actual input shape.
    for port in [2, 3]:
        if port in node.in_nodes() and len(node.in_node(port).in_nodes()) > 0 and \
                'zero_shapes' in node.in_node(port).in_node():
            for i in node.in_node(port).in_node().zero_shapes:
                if node.in_node(port).shape[i] != input_shape[i]:
                    node.in_node(port).value = np.repeat(node.in_node(port).value,
                                                         input_shape[i], axis=i)
                    node.in_node(port).shape[i] = input_shape[i]

    # Output layout follows the input layout: [seq, batch, hidden] or [batch, seq, hidden]
    out_shape = np.array([input_shape[node.sequence_dim], input_shape[node.batch_dim],
                          node.hidden_size], dtype=np.int64)
    if node.batch_dim == 0:
        out_shape = np.array([input_shape[node.batch_dim], input_shape[node.sequence_dim],
                              node.hidden_size], dtype=np.int64)

    num_directions = 2 if node.direction in ['bidirectional'] else 1
    if node.has_num_directions:
        if node.format == 'mxnet' and node.normalized is False:
            # In MXNet RNN layer return output with shape [seq_len, batch_size, hidden_size * num_directions]
            out_shape[-1] *= num_directions
        else:
            # ONNX-like, insert extra dimension to output shape for num_directions
            out_shape = np.insert(out_shape, 1, np.int64(num_directions))
    node.out_node(0).shape = out_shape

    # 3. Extra outputs for hidden/cell states shape calculations (optional)
    state_size = np.array([input_shape[node.batch_dim], node.hidden_size], dtype=np.int64)
    if node.has_num_directions:
        state_size = np.insert(state_size, 0, num_directions)

    if node.multilayers:
        # For multilayer case state sizes from every layer will be concatenated by last axis
        num_layers = node.num_layers
        state_size[-1] *= num_layers

    for i in out_ports:
        # If node hasn't consumers for hidden/cells state -> create them
        if i not in node.out_nodes():
            data_node = Op._create_data_node(node.graph,
                                             name=node.node + '/ExtraOutput/' + str(i),
                                             attrs={'executable': True})
            if i not in node.out_ports():
                node.add_output_port(i)
            node.graph.add_edge(node.id, data_node.id, key=0, out=i)
            add_opoutput(node.graph, data_node.id, 0, False)
        else:
            data_node = node.out_node(i)
        data_node.shape = state_size.copy()
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Normalize a matched StridedSlice by materializing its shrink_axis_mask /
    new_axis_mask as explicit Reshape layers.

    For each active mask, a new intermediate data node is inserted after the
    StridedSlice (carrying the un-shrunk / un-expanded shape) followed by a
    Reshape producing the original output shape; the corresponding mask entries
    on the StridedSlice node are cleared in place.

    :param graph: graph to mutate in place.
    :param match: pattern-match dict containing the 'strided_slice' node.
    """
    # add Reshape for shrink_axis_mask
    if True in match['strided_slice']['shrink_axis_mask']:
        log.info("StridedSlice op with shrink mask '{}' has been detected".format(match['strided_slice'].id))
        node = match['strided_slice']

        # Only handle the canonical 4-input, single-output form
        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return

        shape_in = node.in_node().shape
        shape_out = node.out_node().shape
        dim = shape_out.copy()
        ss_shape = []
        k = 0

        # Don't permute reshape if channels were squeezed
        dont_permute = False
        if graph.graph['layout'] == 'NHWC' and node['shrink_axis_mask'][-1] == True:
            dont_permute = True

        # Rebuild the pre-shrink shape: kept axes take the next output dim,
        # shrunk axes become explicit 1s (and the mask entry is cleared).
        for i in range(0, len(node['shrink_axis_mask'])):
            if not node['shrink_axis_mask'][i]:
                ss_shape.append(shape_out[k])
                k = k + 1
            else:
                node['shrink_axis_mask'][i] = False
                ss_shape.append(1)

        out_node = node.out_node(0)

        # insert data node for StridedSlice: re-route its output edge to the
        # new intermediate data node, preserving the edge attributes
        data_node = Op._create_data_node(graph, node.name + "/Reshape_shrink_data", {'shape': ss_shape})
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)

        # insert Reshape (flagged nchw_layout when layout permutation must be skipped)
        if dont_permute:
            reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
                                          dim=np.array(dim, dtype=np.int64), nchw_layout=True))
            reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
                                                              data_nodes=[out_node])
            reshape_data_node['nchw_layout'] = True
        else:
            reshape = Reshape(graph, dict(name=node.name + "/Reshape_shrink",
                                          dim=np.array(dim, dtype=np.int64)))
            reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
                                                              data_nodes=[out_node])

    # add Reshape for new_axis_mask
    if True in match['strided_slice']['new_axis_mask']:
        log.info("StridedSlice op with new axis mask '{}' has been detected".format(match['strided_slice'].id))
        node = match['strided_slice']

        if len(node.in_nodes()) != 4 or len(node.out_nodes()) != 1:
            return
        shape_in = node.in_node().shape
        shape_out = node.out_node().shape
        dim = shape_out.copy()
        ss_shape = []
        # Drop the inserted axes to get the pre-expansion shape; clear each
        # consumed mask entry in place.
        for i in range(0, len(node['new_axis_mask'])):
            if not node['new_axis_mask'][i]:
                ss_shape.append(shape_out[i])
            else:
                node['new_axis_mask'][i] = False

        out_node = node.out_node(0)
        # insert data node for StridedSlice (same edge re-routing as above)
        data_node = Op._create_data_node(graph, node.name + "/Reshape_new_data", {'shape': ss_shape})
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)
        graph.add_edge(node.id, data_node.id, **attrs)

        # insert Reshape restoring the expanded output shape
        reshape = Reshape(graph, dict(name=node.name + "/Reshape_new",
                                      dim=np.array(dim, dtype=np.int64)))
        reshape_data_node = reshape.create_node_with_data([data_node], reshape.attrs,
                                                          data_nodes=[out_node])