def duplicate_shared_weights(graph: nx.MultiDiGraph):
    """
    Give every consumer of a shared constant data node its own copy of that node.

    A data node qualifies when it feeds more than one consumer and its ``value``
    is not None. The consumer at index 0 of ``out_nodes()`` stays attached to the
    original node; every other consumer is detached and re-attached to a new
    input data node named ``Copy_<original id>`` holding ``np.array(value)``.

    Raises:
        Error: if the original node is linked to a consumer by more than one edge.
    """
    data_nodes = []
    for node_id in graph.nodes():
        if Node(graph, node_id).soft_get('kind') == 'data':
            data_nodes.append(Node(graph, node_id))

    for shared in data_nodes:
        # Only constant data nodes with several consumers need duplication
        if not (len(shared.out_nodes()) > 1 and shared.value is not None):
            continue
        # Repeatedly peel off the consumer at index 1 until only one remains
        while len(shared.out_nodes()) > 1:
            consumer = shared.out_node(1)
            parallel_edges = graph.get_edge_data(shared.id, consumer.id)
            if len(parallel_edges) != 1:
                raise Error(
                    'There is more than one edge from {} node to {} node.'.format(shared.id, consumer.id))
            edge_attrs = parallel_edges[0]
            graph.remove_edge(shared.id, consumer.id)
            copy_node = Op.create_input_data_node(graph, "Copy_{}".format(shared.id), np.array(shared.value),
                                                 graph.node[shared.id])
            graph.add_edges_from([(copy_node.id, consumer.id, edge_attrs)])
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Repack the weights of a matched LSTM and re-attach its initial states through Reshapes.

    The params, hidden-state and cell-state edges into ``lstm`` are removed,
    ``self.repack_weights`` is called with a copy of the params value, and the
    hidden/cell state data nodes are reconnected on input ports 3 and 4 through
    Reshape nodes with ``dim=[lstm.in_node(0).shape[0], lstm.hidden_size]``.
    """
    input = match['input']
    lstm = match['lstm']
    # Copied, presumably so repack_weights can work on it without touching the matched node's value
    params = match['params'].value.copy()
    hidden_state = match['hidden_state']
    cell_state = match['cell_state']

    # Snapshot the state edge attributes before the edges are removed; they are
    # reused below with a new 'in' port
    hidden_state_edge_attrs = deepcopy(graph.get_edge_data(hidden_state.id, lstm.id)[0])
    cell_state_edge_attrs = deepcopy(graph.get_edge_data(cell_state.id, lstm.id)[0])

    graph.remove_edge(match['params'].id, lstm.id)
    graph.remove_edge(match['hidden_state'].id, lstm.id)
    graph.remove_edge(match['cell_state'].id, lstm.id)

    # NOTE(review): repack_weights is defined elsewhere; presumably it reconnects the
    # repacked weights to ``lstm`` (the input-count checks below depend on that) - verify
    self.repack_weights(graph, input, lstm, params)

    reshape = Reshape(graph, dict(dim=[lstm.in_node(0).shape[0], lstm.hidden_size]))

    # Hidden state is re-attached on input port 3 through a Reshape
    if len(lstm.in_nodes()) > 2:
        hidden_state_edge_attrs['in'] = 3
        new_init_h = reshape.create_node_with_data([hidden_state], attrs=dict(name=lstm.name + '/HiddenStateResize'))
        graph.add_edge(new_init_h.id, lstm.id, **hidden_state_edge_attrs)

    # Cell state is re-attached on input port 4 through the same Reshape op
    if len(lstm.in_nodes()) > 3:
        cell_state_edge_attrs['in'] = 4
        new_init_c = reshape.create_node_with_data([cell_state], attrs=dict(name=lstm.name + '/CellStateResize'))
        graph.add_edge(new_init_c.id, lstm.id, **cell_state_edge_attrs)
def _insert_pooling(graph: nx.MultiDiGraph, first_node: Node, second_node: Node, spatial_dims):
    """
    Splice a point-wise (1x1 window) max Pooling between two connected nodes.

    The pooling stride is taken from ``second_node.stride_prop``. The single edge
    ``first_node -> second_node`` is replaced by
    ``first_node -> Pooling -> pooling data -> second_node``, and the removed
    edge's attributes are reused for the last hop.
    """
    log.debug("STRIDE PROP: Insert pooling between {} and {}".format(first_node.name, second_node.name))
    stride_prop = second_node.stride_prop

    edges_between = graph.get_edge_data(first_node.id, second_node.id)
    assert len(edges_between) == 1
    eattrs = edges_between[0]
    graph.remove_edge(first_node.id, second_node.id)

    pooling_attrs = dict(
        name='Pooling_',
        spatial_dims=spatial_dims,
        window=np.array([1, 1, 1, 1]),
        output_spatial_shape=None,
        stride=np.array(stride_prop),
        pad_spatial_shape=np.array([[0, 0], [0, 0]]),
        pad=np.array([[0, 0], [0, 0], [0, 0], [0, 0]]),
        pool_method='max',
        is_partial_inferred=False,
    )
    pooling_data = Pooling(graph, pooling_attrs).create_node_with_data([first_node])
    _clean_fw_tensor_attrs(pooling_data)
    graph.add_edges_from([(pooling_data.id, second_node.id, eattrs)])
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Move a ReLU from in front of Reshape1 to directly in front of the Convolution.

    Nothing happens when the convolution has no padding (``np.max(conv.pad) == 0``).
    Otherwise ``InputData -> ReLU -> Data -> Reshape1`` becomes
    ``InputData -> Reshape1`` and ``Reshape2Data -> Convolution`` becomes
    ``Reshape2Data -> ReLU -> Data -> Convolution``.
    """
    relu = match['relu']
    reshape1 = match['reshape1']
    reshape2_data = match['reshape2_data']
    conv = match['conv']
    if np.max(conv.pad) == 0:
        return

    relu_in = relu.in_node()
    # The ReLU -> Data edge itself is never removed, so this stays valid below
    relu_out = relu.out_node()

    # Take ReLU out of InputData -> ReLU -> Data -> Reshape1 and bridge InputData -> Reshape1
    to_reshape_attrs = graph.get_edge_data(relu_out.id, reshape1.id)[0]
    graph.remove_edge(relu_in.id, relu.id)
    graph.remove_edge(relu_out.id, reshape1.id)
    graph.add_edges_from([(relu_in.id, reshape1.id, to_reshape_attrs)])

    # Splice ReLU into Reshape2Data -> ReLU -> Data -> Convolution
    to_conv_attrs = graph.get_edge_data(reshape2_data.id, conv.id)[0]
    graph.remove_edge(reshape2_data.id, conv.id)
    graph.add_edges_from([
        (reshape2_data.id, relu.id, {'in': 0}),
        (relu_out.id, conv.id, to_conv_attrs),
    ])
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
    """
    Honour per-edge ``new_shape`` requests on the outputs of data nodes.

    For each data node, consumers whose edge carries a ``new_shape`` that differs
    from the node's own shape are grouped by that shape. A data node with a valid
    ``value`` gets one reshaped constant copy per requested shape; any other data
    node gets one Reshape layer per requested shape. Every grouped consumer is
    rewired to the copy / Reshape output, and ``new_shape`` is removed from its edge.
    """

    def move_consumer(source, target, consumer):
        # Re-home the source -> consumer edge onto target, dropping the consumed 'new_shape' request
        attrs = graph.get_edge_data(source.id, consumer.id)[0]
        del attrs['new_shape']
        graph.remove_edge(source.id, consumer.id)
        graph.add_edge(target.id, consumer.id, **attrs)

    data_nodes = [Node(graph, node_id) for node_id in graph.node if Node(graph, node_id).kind == 'data']

    for data_node in data_nodes:
        # {requested shape as a tuple: [consumers requesting it]}, in first-seen order
        consumers_by_shape = {}
        for consumer in data_node.out_nodes():
            attrs = graph.get_edge_data(data_node.id, consumer.id)[0]
            if 'new_shape' not in attrs or np.array_equal(attrs['new_shape'], data_node.shape):
                continue
            consumers_by_shape.setdefault(tuple(attrs['new_shape']), []).append(consumer)

        is_const = data_node.has_valid('value')
        for shape_key, consumers in consumers_by_shape.items():
            shape = list(shape_key)
            if is_const:
                # Constant: duplicate the value already reshaped
                new_value = np.reshape(data_node.value, shape)
                target = Op.create_input_data_node(graph, data_node.id + '/copy', value=np.array(new_value))
            else:
                # Non-constant: insert a Reshape layer between the node and its consumers
                reshape = Reshape(graph, attrs={'dim': shape, 'name': 'EltwiseReshapeNormalization'})
                target = reshape.create_node_with_data(inputs=[data_node])
            for consumer in consumers:
                move_consumer(data_node, target, consumer)
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """
    Fold a non-unit LRN ``bias`` into a trailing multiplication.

    When the LRN node has a valid ``bias`` other than 1, ``alpha`` is divided by
    ``bias``, and a Mul by the constant ``1 / bias ** beta`` is appended after the
    LRN. Every former LRN consumer is moved onto the Mul output.
    """
    lrn = match['op']
    if not lrn.has_valid('bias') or lrn.bias == 1:
        return

    scale_value = np.array(1. / (lrn.bias ** lrn.beta))
    lrn.alpha /= lrn.bias
    scale_const = Const(graph, dict(value=scale_value, shape=scale_value.shape))

    # Snapshot consumers before the Mul gets attached to the LRN output
    consumers = list(lrn.out_nodes().values())

    mul = Mul(graph, dict(name=lrn.id + "/Mul_")).create_node(inputs=[lrn, scale_const.create_node()])

    for consumer in consumers:
        attrs = graph.get_edge_data(lrn.id, consumer.id)[0]
        graph.remove_edge(lrn.id, consumer.id)
        graph.add_edges_from([(mul.id, consumer.id, attrs)])
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """
    Insert a Flatten-style Reshape between SoftmaxActivation and MultiBoxDetection.

    DetectionOutput in IE expects a flattened input coming from SoftMax, so the
    SoftmaxActivation -> MultiBoxDetection edge is replaced by
    SoftmaxActivation -> Reshape(dim=[0, -1]) -> MultiBoxDetection. The final edge
    keeps the original in/out port numbers. The softmax ``axis`` is set to -1.

    Parameters
    ----------
    graph : nx.MultiDiGraph
        Graph with loaded model.
    match : dict
        Patterns which were found in graph structure.
    """
    softmax = match['softmax_activation']
    detection = match['multi_box_detection']
    softmax['axis'] = -1

    first_edge = graph.get_edge_data(softmax.id, detection.id)[0]
    out_port, in_port = first_edge['out'], first_edge['in']
    graph.remove_edge(softmax.id, detection.id)

    flatten_symbol = {
        'op': 'Flatten',
        'name': detection.name + '/Reshape_',
        'dim': [0, -1],
        'axis': 1,
        'end_axis': -1,
    }
    reshape_node = Reshape(graph, {'symbol_dict': flatten_symbol}).create_node([softmax])
    reshape_node['dim'] = [0, -1]
    create_edge(reshape_node, detection, in_port=in_port, out_port=out_port)
def non_decreasing_magnetism(graph: MultiDiGraph, k: int) -> LayeredGraph:
    """
    Build the ``k + 1``-layer graph used for non-decreasing magnetism walks.

    Node ``n`` on layer ``i`` is encoded as ``n + i * offset``, where
    ``offset = layered_graph.max_node``. Arcs are classified by ``edge_data[arc_type]``:

    * ``magnetic`` arcs lift layer ``i`` to ``i + 1`` (for ``i < k``) and are also
      copied inside the top layer ``k``;
    * ``non_magnetic`` arcs are copied inside layers ``0..k-1``, and inside the top
      layer only when their source has no magnetic outgoing arc.

    Arcs of any other type are ignored.
    """
    layers = k + 1
    layered_graph = LayeredGraph(graph, layers)
    offset = layered_graph.max_node
    top = k * offset

    # node -> True when at least one of its outgoing arcs is magnetic
    has_magnetic_outs = {}
    for node in graph.nodes:
        out_types = [arc[2][arc_type] for arc in graph.out_edges(node, data=True)]
        has_magnetic_outs[node] = magnetic in out_types

    for u, v, key in graph.edges(keys=True):
        edge_data = graph.get_edge_data(u, v, key)
        edge_type = edge_data[arc_type]
        if edge_type == magnetic:
            for level in range(k):
                layered_graph.add_edge(u + level * offset, v + (level + 1) * offset, **edge_data)
            layered_graph.add_edge(u + top, v + top, **edge_data)
        elif edge_type == non_magnetic:
            for level in range(k):
                shift = level * offset
                layered_graph.add_edge(u + shift, v + shift, **edge_data)
            if not has_magnetic_outs[u]:
                layered_graph.add_edge(u + top, v + top, **edge_data)
    return layered_graph
def get_next_stop(graph: nx.MultiDiGraph, node_id: int, route_id: int, already_visited: list) -> tuple:
    """
    Find the first not-yet-visited stop reachable from ``node_id`` on ``route_id``.

    For each neighbour (in ``graph.neighbors`` order) that is not in
    ``already_visited``, the parallel edge keyed ``route_id`` is looked up. The
    first such edge found yields
    ``(edge['route_id'], node_id, neighbour_id, edge['duration'], edge['period'])``.
    Returns ``None`` when no neighbour has an edge with that key.
    """
    unvisited = (n for n in graph.neighbors(node_id) if n not in already_visited)
    for neighbour_id in unvisited:
        edge = graph.get_edge_data(node_id, neighbour_id, route_id)
        if edge is None:
            continue
        return edge['route_id'], node_id, neighbour_id, edge['duration'], edge['period']
    return None
def barrier_restricted(graph: MultiDiGraph, k: int) -> LayeredGraph:
    """
    Build the ``k + 1``-layer graph for barrier-restricted walks.

    Node ``n`` on layer ``i`` is encoded as ``n + i * offset``, where
    ``offset = layered_graph.max_node``. Arcs are copied according to their
    ``'arc_type'`` attribute:

    * ``'e'`` - copied inside every layer ``0..k``;
    * ``'a'`` - lifts layer ``i`` to ``i + 1`` for ``i < k`` and is also copied
      inside the top layer ``k``;
    * ``'b'`` - copied only inside the top layer ``k``.

    Arcs of any other type are ignored.
    """
    layers = k + 1
    layered_graph = LayeredGraph(graph, layers)
    offset = layered_graph.max_node
    top = k * offset

    for u, v, key in graph.edges(keys=True):
        edge_data = graph.get_edge_data(u, v, key)
        kind = edge_data['arc_type']
        if kind == 'e':
            for level in range(layers):
                shift = level * offset
                layered_graph.add_edge(u + shift, v + shift, **edge_data)
        elif kind == 'a':
            for level in range(k):
                layered_graph.add_edge(u + level * offset, v + (level + 1) * offset, **edge_data)
            layered_graph.add_edge(u + top, v + top, **edge_data)
        elif kind == 'b':
            layered_graph.add_edge(u + top, v + top, **edge_data)
    return layered_graph
def pad_op_transform(graph: nx.MultiDiGraph, match: dict):
    """
    Fuse a zero-filled constant Pad into the pads of the op that consumes it.

    The pads come from the Pad's second input when it has two inputs, and from its
    ``pads`` attribute otherwise. Fusion is skipped (with an info log) for a
    non-constant mode, a non-zero fill value, or any padding on the batch or
    feature dimension. On success the pads are added to ``op.pad``,
    ``pad_spatial_shape`` is refreshed, ``auto_pad`` is cleared, and ``op`` is fed
    directly from the Pad's input on port 0.
    """
    op = match['op']
    pad_op = match['pad_op']
    pad_output = match['pad_output']
    input_data = pad_op.in_node(0)
    if len(pad_op.in_nodes()) == 2:
        pads = pad_op.in_node(1).value
    else:
        pads = pad_op.pads

    if pad_op.mode != 'constant':
        log.info('The pad node "{}" with pad mode "{}" cannot be fused.'.format(pad_op.soft_get('name'), pad_op.mode))
        return

    # Mode is necessarily 'constant' from here on
    if pad_op.fill_value != 0.0:
        log.info('The pad node "{}" with non-zero fill value cannot be fused.'.format(pad_op.soft_get('name')))
        return

    rank = len(pad_output.shape)
    layout = op.graph.graph['layout']
    if np.any(pads[get_features_dim(layout, rank)] != 0) or np.any(pads[get_batch_dim(layout, rank)] != 0):
        log.info('The pad node "{}" with padding over feature/batch dimension cannot be fused.'.format(
            pad_op.soft_get('name')))
        return

    op.pad += pads
    op.pad_spatial_shape = op.pad[op.spatial_dims]
    op['auto_pad'] = None

    assert (graph[pad_output.node][op.node][0]['in'] == 0)
    edge_attrs = graph.get_edge_data(pad_output.id, op.id)[0]
    graph.remove_edge(pad_output.id, op.id)
    graph.add_edge(input_data.id, op.id, **{'in': 0, **edge_attrs})
def add_path_step(line_code: int, current_node, from_node_label: str, step_count: int, line_graph: nx.DiGraph,
                  graph: nx.MultiDiGraph, taboo_list: Set[int]):
    """
    Recursively extend the cluster path of one line into ``graph``.

    From ``current_node``, each ``line_graph`` neighbour not yet in ``taboo_list``
    contributes an edge ``from_node_label -> "<step_count + 1>_<cluster>"``, where
    ``cluster`` is the ``'cluster'`` attribute of the line-graph edge. The edge is
    skipped when ``edge_exists`` reports it is already present. Neighbours that get
    an edge are added to ``taboo_list``, which is mutated in place and shared by
    the whole recursion, and are then explored further. Recursion stops when
    ``step_count`` reaches 5.
    """
    if step_count == 5:
        return

    next_step = step_count + 1
    for neighbour in nx.neighbors(line_graph, current_node):
        if neighbour in taboo_list:
            continue
        cluster_number = line_graph.get_edge_data(current_node, neighbour)['cluster']
        to_node_label = '_'.join((str(next_step), str(cluster_number)))
        existing_edges = graph.get_edge_data(from_node_label, to_node_label, default=None)
        if edge_exists(existing_edges, line_code, current_node, neighbour):
            continue
        graph.add_edge(from_node_label, to_node_label, line_code=line_code, from_node=current_node,
                       to_node=neighbour, cluster=cluster_number)
        taboo_list.add(neighbour)
        add_path_step(line_code, neighbour, to_node_label, next_step, line_graph, graph, taboo_list)
def get_edges_data_from(graph: nx.MultiDiGraph, node_from_id: int, route_id: int) -> list:
    """
    Collect the distinct outgoing edges of ``node_from_id`` that belong to ``route_id``.

    Every parallel edge to every neighbour whose ``route_id`` attribute equals
    ``route_id`` is reported as
    ``(route_id, node_from_id, node_to_id, duration, period)``, with the attribute
    values converted by ``int``. Duplicates are collapsed through a set, so the
    order of the returned list carries no meaning.
    """
    unique_edges = {
        (int(attrs['route_id']), node_from_id, node_to_id, int(attrs['duration']), int(attrs['period']))
        for node_to_id in graph.neighbors(node_from_id)
        for attrs in graph.get_edge_data(node_from_id, node_to_id).values()
        if attrs['route_id'] == route_id
    }
    return list(unique_edges)
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    This pass normalizes an FC layer to a 2-D [N, C] input by wrapping it in Reshapes.

    Example:

    (2,16,512)-->FC->(2,16,101)    =>    (2,16,512)-->Reshape-->(32,512)-->FC-->(32,101)-->Reshape-->(2,16,101)

    Nothing is done when the input rank is <= 2, or when the product of all
    non-batch input dims already equals the weights' input-channel dimension.
    """
    fc = match['fc']
    fc_weights = fc.in_node(1)
    fc_output = match['fc_output']
    fc_input = fc.in_node()

    input_shape = fc.in_node().shape
    if len(input_shape) <= 2 or np.prod(fc_input.shape[1:]) == fc_weights.shape[fc_weights.input_channel_dim]:
        return

    # Insert Reshape to normalize input for FC layer that should be in [N,C] layout
    # First Reshape collapses all leading dims into one: [prod(shape[:-1]), shape[-1]]
    first_reshape_shape = np.array([np.prod(input_shape[0:-1]), input_shape[-1]], dtype=np.int64)
    # Second Reshape restores the leading dims with the FC output size as the last dim
    second_reshape_shape = np.array([*input_shape[0:-1], fc['out-size']], dtype=np.int64)
    # Shape of the (now 2-D) FC output between the two Reshapes
    fc_out_shape = np.array([np.prod(input_shape[0:-1]), fc['out-size']], dtype=np.int64)
    first_reshape = Reshape(graph, {'dim': np.array(first_reshape_shape)})
    second_reshape = Reshape(graph, {'dim': np.array(second_reshape_shape)})

    # Keep the original edge attributes so the new connections reuse them
    input_edge_attrs = graph.get_edge_data(fc_input.id, fc.id)[0]
    output_edge_attrs = graph.get_edge_data(fc.id, fc_output.id)[0]

    graph.remove_edge(fc_input.id, fc.id)
    graph.remove_edge(fc.id, fc_output.id)

    # Insert Reshapes before and after FullyConnected layer
    reshape_data = first_reshape.create_node_with_data(inputs=[fc_input])
    graph.add_edge(reshape_data.id, fc.id, **input_edge_attrs)

    # New 2-D data node produced by the FC, carrying the old output edge attributes
    new_fc_output = Op.create_data_node(graph, fc, {'shape': fc_out_shape}, edge_attrs=output_edge_attrs)

    # The second Reshape writes into the original fc_output data node, so downstream consumers are untouched
    second_reshape.create_node_with_data(inputs=[new_fc_output], data_nodes=fc_output)
def merge_edge(g: nx.MultiDiGraph, u: str, v: str, key: str, data: dict, preserve: bool = True) -> dict:
    """
    Merge edge ``u`` -> ``v`` into graph ``g``.

    The edge ``(u, v, key)`` must already exist in ``g``. Its attribute dict (as
    returned by ``g.get_edge_data``) is updated in place with ``data``:

    * properties not yet on the edge are added;
    * properties in ``CORE_EDGE_PROPERTIES`` are never modified;
    * if the existing value is a list, the new value is appended to it (a list
      value contributes its items individually);
    * otherwise, with ``preserve=True`` the existing value is turned into a list
      and the new value is added the same way (a list value contributes its
      items, so no nested list is produced); with ``preserve=False`` the
      existing value is overwritten.

    Parameters
    ----------
    g: nx.MultiDiGraph
        The target graph
    u: str
        Subject node id
    v: str
        Object node id
    key: str
        Edge key
    data: dict
        Edge properties to merge in
    preserve: bool
        Whether or not to preserve conflicting properties

    Returns
    -------
    dict
        The merged edge
    """
    existing_edge = g.get_edge_data(u, v, key)
    # Distinct loop names so the ``v`` parameter is not shadowed
    for prop, value in data.items():
        if prop not in existing_edge:
            logging.debug(f"adding new edge property {prop} to edge")
            existing_edge[prop] = value
        elif prop in CORE_EDGE_PROPERTIES:
            logging.debug(f"cannot modify core edge property '{prop}': {existing_edge[prop]} vs {value}")
        elif isinstance(existing_edge[prop], list):
            # append
            logging.debug(f"edge property '{prop}' is a list; Appending {value} to {existing_edge[prop]}")
            if isinstance(value, list):
                existing_edge[prop].extend(value)
            else:
                existing_edge[prop].append(value)
        elif preserve:
            # convert to a list and add the new value(s), consistently with the list branch above
            logging.debug(f"preserving edge property '{prop}'; Appending {value} to {existing_edge[prop]}")
            merged = [existing_edge[prop]]
            if isinstance(value, list):
                merged.extend(value)
            else:
                merged.append(value)
            existing_edge[prop] = merged
        else:
            # overwrite the value for key
            logging.debug(f"overwriting edge property '{prop}'; Replacing {existing_edge[prop]} with {value}")
            existing_edge[prop] = value
    return existing_edge
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Replace swapaxes layer:
    swapaxes -> Reshape

    The swapaxes op is disconnected from its input and output data nodes. A
    Reshape whose ``dim`` is the swapaxes *input* shape is created between those
    same data nodes, reusing both removed edges' attributes.
    """
    swapaxes = match['swapaxes']
    src_data = swapaxes.in_node()
    dst_data = swapaxes.out_node()

    edge_attrs = [
        graph.get_edge_data(src_data.id, swapaxes.id)[0],
        graph.get_edge_data(swapaxes.id, dst_data.id)[0],
    ]
    graph.remove_edge(src_data.id, swapaxes.id)
    graph.remove_edge(swapaxes.id, dst_data.id)

    reshape = Reshape(graph, {'dim': np.array(src_data.shape)})
    reshape.create_node_with_data(inputs=[src_data], data_nodes=[dst_data], edge_attrs=edge_attrs)
def dilated_convolution_action(graph: nx.MultiDiGraph, match: dict):
    """
    Collapse SpaceToBatch -> Convolution -> BatchToSpace into a single dilated Convolution.

    The convolution is rewired to read straight from the SpaceToBatch input (port 0)
    and to write straight to the BatchToSpace output (port 0). Its spatial
    dilation becomes the SpaceToBatch block size, its spatial pads become
    ``stb_pad - bts_crop``, and ``auto_pad`` is cleared.
    """
    conv = match['conv']
    stb = match['space_to_batch']
    bts = match['batch_to_space']
    block_size = match['stb_bs']
    input_data = match['input']
    output_data = match['output']
    stb_out = match['stb_output']
    conv_out = match['conv_output']

    in_edge_attrs = graph.get_edge_data(input_data.id, stb.id)[0]
    out_edge_attrs = graph.get_edge_data(bts.id, output_data.id)[0]

    edges_to_drop = (
        (input_data.id, stb.id),
        (stb_out.id, conv.id),
        (conv.id, conv_out.id),
        (bts.id, output_data.id),
    )
    for src, dst in edges_to_drop:
        graph.remove_edge(src, dst)

    conv.dilation[conv.spatial_dims] = block_size.value
    pad = match['stb_pad'].value - match['bts_crop'].value
    conv.pad[conv.spatial_dims] = [[row[0], row[1]] for row in pad]
    conv['auto_pad'] = None

    graph.add_edges_from([
        (input_data.id, conv.id, {'in': 0, **in_edge_attrs}),
        (conv.id, output_data.id, {'out': 0, **out_edge_attrs}),
    ])
def states_squeeze(self, graph: nx.MultiDiGraph, match: dict):
    """
    Route the LSTM initial hidden/cell states through a 2-D Reshape.

    Both states are reshaped to ``[lstm.in_node(0).shape[0], lstm.hidden_size]``.
    When the LSTM has more than 3 inputs, the input on port 5 (hidden state) is
    re-attached on port 3. When it then has more than 4 inputs, the input on
    port 6 (cell state) is re-attached on port 4. Apart from ``'in'``, the
    original edge attributes are kept.
    """
    lstm = match['lstm']
    reshape = Reshape(graph, dict(dim=[lstm.in_node(0).shape[0], lstm.hidden_size]))

    # (input count that must be exceeded, current port, new port, name suffix): hidden, then cell
    plan = (
        (3, 5, 3, '/HiddenStateResize'),
        (4, 6, 4, '/CellStateResize'),
    )
    for min_inputs, old_port, new_port, suffix in plan:
        if len(lstm.in_nodes()) <= min_inputs:
            # Fewer inputs than the hidden state needs also rules out the cell state
            break
        state = lstm.in_node(old_port)
        attrs = deepcopy(graph.get_edge_data(state.id, lstm.id)[0])
        attrs['in'] = new_port
        graph.remove_edge(state.id, lstm.id)
        resized = reshape.create_node_with_data([state], dict(name=lstm.name + suffix))
        graph.add_edge(resized.id, lstm.id, **attrs)
def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """
    DetectionOutput layer has another order of inputs unlike mxnet.

    Reorder the first two _contrib_MultiBoxDetection inputs for a correct
    conversion to DetectionOutput. The producers on ports 0 (conf) and 1 (loc)
    exchange their (in, out) port pairs: loc is reconnected with conf's former
    ports, and conf with loc's.

    Parameters
    ----------
    graph : nx.MultiDiGraph
        Graph with loaded model.
    """
    detection = match['multi_box_detection']
    conf_node = detection.in_node(0)
    loc_node = detection.in_node(1)

    def ports_of(producer):
        # (out, in) port pair of the first producer -> detection edge
        first_edge = graph.get_edge_data(producer.id, detection.id)[0]
        return first_edge['out'], first_edge['in']

    conf_out_port, conf_in_port = ports_of(conf_node)
    loc_out_port, loc_in_port = ports_of(loc_node)

    for producer in (conf_node, loc_node):
        graph.remove_edge(producer.id, detection.id)

    create_edge(loc_node, detection, in_port=conf_in_port, out_port=conf_out_port)
    create_edge(conf_node, detection, in_port=loc_in_port, out_port=loc_out_port)
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """
    Replace a scalar-add op with an explicit ``Add(input, Const(node.scalar))``.

    The op's input edge is dropped, the Add is created from the original input
    and a constant holding ``node.scalar``, and every consumer of the op is moved
    onto the Add.

    Returns:
        list: a single-element list with the id of the new Add node.
    """
    source = node.in_node()
    consumers = list(node.out_nodes().values())
    graph.remove_edge(source.id, node.id)

    const_name = node.id + '/const'
    add_name = node.id + '/add_'
    scalar_const = Const(graph, dict(value=node.scalar, shape=node.scalar.shape, symbol_dict={'name': const_name}))
    add_node = Add(graph, dict(name=add_name, symbol_dict={'name': add_name})).create_node(
        inputs=[source, scalar_const.create_node()])

    for consumer in consumers:
        attrs = graph.get_edge_data(node.id, consumer.id)[0]
        graph.remove_edge(node.id, consumer.id)
        graph.add_edges_from([(add_node.id, consumer.id, attrs)])
    return [add_node.id]
def toDiGraph(g_multi: nx.MultiDiGraph) -> nx.DiGraph:
    """
    Collapse a MultiDiGraph into a DiGraph, keeping the set of ``computers`` per pair.

    Every ordered node pair joined by one or more parallel edges becomes a single
    edge. Its ``computers`` attribute is the frozenset of the ``"computers"``
    values found on those parallel edges; edges without that attribute contribute
    nothing. Only nodes incident to at least one edge appear in the result.
    """
    target = "computers"
    g_single = nx.DiGraph()
    # Iterating the adjacency visits each (source, target) pair exactly once together
    # with all of its parallel edges. Iterating g_multi.edges() instead yields the pair
    # once per parallel edge and rebuilds the same set every time.
    for s, neighbours in g_multi.adj.items():
        for t, parallel_edges in neighbours.items():
            computers = frozenset(
                data[target] for data in parallel_edges.values() if target in data)
            g_single.add_edge(s, t, computers=computers)
    return g_single
def simplify(g: nx.MultiDiGraph, input_names: Sequence = None, symbolic_function_map: Dict = None):
    """Compile computational graph `g` into a (possibly simplified) symbolic expression.

    Args:
        g (nx.MultiDiGraph): a computational graph. Nodes with a negative id are
            inputs (id ``-i`` is the i-th input). Every other node has a ``"func"``
            attribute, and each incoming edge carries a ``"weight"`` (multiplied
            onto the operand) and an ``"order"`` (argument position).
        input_names (Sequence): a list of names, each for one input. If `None`,
            then a default name "vi" is used for the i-th input.
        symbolic_function_map ([Dict], optional): Map each function to a symbolic
            one in `sympy`. Defaults to None. If `None`, then the
            `DEFAULT_SYMBOLIC_FUNCTION_MAP` is used.

    Return:
        a (simplified) symbol expression
        For example, `add(sub(3, 3), x)` may be simplified to `x`. Note that this
        method is used to simplify the **final** solution rather than during evolution.
    """
    if symbolic_function_map is None:
        symbolic_function_map = DEFAULT_SYMBOLIC_FUNCTION_MAP

    # topological order guarantees every predecessor's expression is built first
    order = list(nx.topological_sort(g))
    exprs = {}
    for node_id in order:
        if node_id < 0:
            # input node in CGP
            if input_names is None:
                name = f"v{-node_id}"
            else:
                name = input_names[-node_id - 1]
            exprs[node_id] = sp.Symbol(name)
            continue
        # (source, weight, order) for every incoming edge, parallel edges included
        operands = [
            (src, attr["weight"], attr["order"])
            for src in g.predecessors(node_id)
            for attr in g.get_edge_data(src, node_id).values()
        ]
        operands.sort(key=operator.itemgetter(2))
        sym_func = symbolic_function_map[g.nodes[node_id]["func"]]
        expr = sym_func(*(weight * exprs[src] for src, weight, _ in operands))
        exprs[node_id] = sp.simplify(expr) if PP_FORMULA_SIMPLIFICATION else expr
    # the unique output is the last node
    return exprs[order[-1]]
def process_neighbours(start_node_label: str, graph: nx.DiGraph, cluster_graph: nx.MultiDiGraph,
                       _expanded: set = None):
    """
    Copy the part of ``cluster_graph`` reachable from ``start_node_label`` into ``graph``.

    Each reachable edge ``a -> b`` of the multigraph becomes a single edge of
    ``graph`` whose ``number_of_edges`` attribute is the number of parallel edges
    from ``a`` to ``b``. Debug information is printed for each neighbour visited.

    Each node is expanded at most once, tracked in ``_expanded``, which is
    internal and shared across the recursion. Cycles in ``cluster_graph``
    therefore terminate instead of recursing until RecursionError, and nodes
    reachable along several paths are not re-expanded.
    """
    if _expanded is None:
        _expanded = set()
    _expanded.add(start_node_label)

    for neighbour in cluster_graph.neighbors(start_node_label):
        # Attribute dict of the neighbour node itself. nx.get_node_attributes expects
        # an attribute *name* as its second argument, not a node.
        node_attributes = dict(cluster_graph.nodes[neighbour])
        edge_data = cluster_graph.get_edge_data(start_node_label, neighbour)
        print("Neighbour:", neighbour)
        print("Node attributes:", node_attributes)
        print("Edge data:", len(edge_data))
        graph.add_node(neighbour)
        # Create an edge with an attribute saying how many
        # edges there are between the nodes in the multigraph
        graph.add_edge(start_node_label, neighbour, number_of_edges=len(edge_data))
        if neighbour not in _expanded:
            process_neighbours(neighbour, graph, cluster_graph, _expanded)
def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """
    Protect a rank>=4 -> rank-3 Reshape from layout permutation by inserting an input Permute.

    Skipped when the graph layout is 'NCHW', the Reshape has ``nchw_layout`` set,
    or its ``correct_data_layout`` is True. Otherwise, when the input has rank >= 4,
    the output has rank 3, the channel dim is not 1, and the H/W dims are not both
    1 (all measured with the NCHW helpers), the input edge is redirected through a
    Permute whose order is ``PermuteAttrs.get_nchw_to_nhwc_permutation(rank).perm``.

    NOTE(review): only the input Permute is inserted here; ``permutation_back`` is
    computed but never used, and it is built with the same nchw->nhwc helper as
    ``permutation``. Confirm whether an output Permute (and an nhwc->nchw order)
    was intended.
    """
    reshape = match['reshape']
    assert len(reshape.in_nodes()) > 0
    if graph.graph['layout'] == 'NCHW' or reshape.has_and_set('nchw_layout') or\
            reshape.soft_get('correct_data_layout') is True:
        return

    input_node = reshape.in_node()
    output_node = reshape.out_node()
    input_shape = input_node.shape
    output_shape = output_node.shape

    if len(input_shape) >= 4 and len(output_shape) == 3:
        # Check that we will permute some shapes in this Reshape by our permutation pass
        layout = 'NCHW'
        c_idx = get_features_dim(layout, len(input_shape))
        hw_idx = [get_width_dim(layout, len(input_shape)), get_height_dim(layout, len(input_shape))]
        if input_shape[c_idx] != 1 and np.any(input_shape[hw_idx] != [1, 1]):
            # then nhwc -> nchw permutation can change shapes significantly
            # We need to wrap up node with NCHW -> NHWC permutes and don't touch it later
            permutation = PermuteAttrs.get_nchw_to_nhwc_permutation(len(input_shape))
            # NOTE(review): unused below
            permutation_back = PermuteAttrs.get_nchw_to_nhwc_permutation(len(input_shape))

            # 1. Insert input Permute
            #    This Permute will permute input from original input layout to operation layout
            edge_attrs = graph.get_edge_data(input_node.id, reshape.id)[0]
            graph.remove_edge(input_node.id, reshape.id)

            permute_op = Permute(graph, {'order': permutation.perm, 'name': reshape.name + '/Permute_'})
            permute_data_node = permute_op.create_node_with_data([input_node])

            # Reuse the original input edge attributes for the Permute output -> Reshape edge
            graph.add_edge(permute_data_node.id, reshape.id, **edge_attrs)
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """
    Fold a non-unit SoftMax ``temperature`` into a preceding multiplication.

    When the matched softmax node has a ``temperature`` other than 1.0, its input
    is rerouted through ``Mul(input, Const(1 / temperature))``, so the graph
    becomes ``input -> Mul -> SoftMax``.
    """
    node = match['softmax']
    if 'temperature' not in node or node['temperature'] == 1.0:
        return

    in_node = node.in_node()
    # Attributes of the edge being replaced. The new Mul -> SoftMax edge must land on
    # the same SoftMax input port, so it is built from these attributes. The SoftMax
    # output edge's attributes must not be used: they carry the consumer's ports.
    edge_attrs = dict(graph.get_edge_data(in_node.id, node.id)[0])
    # Output port of in_node that fed the SoftMax; the Mul must read the same port
    in_node_port = edge_attrs.get('out', 0)
    graph.remove_edge(in_node.id, node.id)

    temperature = np.array([1.0 / node.temperature])
    scalar_value_op = Const(graph, dict(value=temperature, shape=temperature.shape,
                                        symbol_dict={'name': node.id + '/const'}))
    mul_op = Mul(graph, dict(name=node.id + '/mul_', symbol_dict={'name': node.id + '/mul_'}))
    mul_node = mul_op.create_node(inputs=[(in_node, in_node_port), scalar_value_op.create_node()])

    # The Mul has a single output, so the new edge leaves from its port 0
    edge_attrs['out'] = 0
    graph.add_edges_from([(mul_node.id, node.id, edge_attrs)])
def replace_op(self, graph: nx.MultiDiGraph, node: Node):
    """
    Wrap a Kaldi convolution between two Reshape nodes (axis=1, ``Reshape.kaldi_infer``).

    The leading Reshape reads from the same output port of ``node.in_node(0)``
    that fed ``node``. The Convolution is built from ``node``'s attributes.

    Returns:
        list: the id of the trailing Reshape, which replaces ``node``.
    """
    source = node.in_node(0)
    source_port = graph.get_edge_data(source.id, node.id)[0]['out']

    def kaldi_reshape(name):
        return Reshape(graph, {'name': name, 'axis': 1, 'infer': Reshape.kaldi_infer})

    pre_reshape = kaldi_reshape('/Reshape/' + node.name).create_node([(source, source_port)])
    conv = Convolution(graph, node.attrs()).create_node([pre_reshape])
    post_reshape = kaldi_reshape(node.name + '/Reshape/').create_node([conv])
    return [post_reshape.id]
def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
    """
    Make 1-D Eltwise inputs broadcastable along the feature dimension.

    For every Eltwise op whose ``get_value_id`` is None and whose output has rank
    4 or 5, each 1-D input whose length equals the output feature dimension (per
    ``graph.graph['layout']``) is routed through a Reshape to
    ``shape_for_layout(layout, batch=1, features=C, height=1, width=1, depth=1 or None)``.
    The removed input edge's attributes are reused for the Reshape -> Eltwise edge.
    """
    layout = graph.graph['layout']
    for op_id in list(graph.nodes()):
        op_attrs = graph.node[op_id]
        if 'type' not in op_attrs or op_attrs['type'] != 'Eltwise':
            continue
        if get_value_id(Node(graph, op_id)) is not None:
            continue

        eltwise = Node(graph, op_id)
        out_shape = eltwise.out_node().shape
        rank = len(out_shape)
        if not 4 <= rank <= 5:
            continue
        out_features = out_shape[get_features_dim(layout, rank)]

        for in_data in eltwise.in_nodes().values():
            # rank is 4 or 5 here, so a 1-D input always has a different rank
            if len(in_data.shape) != 1 or not out_features == in_data.shape[0]:
                continue
            in_atts = deepcopy(graph.get_edge_data(in_data.id, op_id)[0])
            graph.remove_edge(in_data.id, op_id)
            new_shape = shape_for_layout(layout, batch=1, features=out_features, height=1, width=1,
                                         depth=1 if rank == 5 else None)
            reshape = Reshape(graph, attrs={'dim': new_shape, 'name': in_data.id + '/Broadcast'})
            reshaped = reshape.create_node_with_data([in_data])
            graph.add_edge(reshaped.id, eltwise.id, **in_atts)
def replace_sub_graph(self, graph: nx.MultiDiGraph, match: dict):
    """
    Replace an ImageScaler operation with an equivalent Mul -> Add sequence.

    The Mul by ``op.scale`` is emitted only when some scale element differs from
    1, and the Add of ``op.bias`` only when some bias element differs from 0. All
    consumers of the ImageScaler are moved to the last node of that chain, which
    is the original input if neither is needed. The ImageScaler is then
    disconnected from its input.
    """
    op = match['op']

    # Skip useless weights (all ones) and biases (all zeros)
    has_weights = not all(x == 1 for x in np.nditer(op.scale))
    has_bias = not all(x == 0 for x in np.nditer(op.bias))

    consumers = list(op.out_nodes().values())
    assert len(op.in_nodes()) == 1
    tail = op.in_node()

    if has_weights:
        scale = Const(graph, dict(value=op.scale, shape=op.scale.shape))
        tail = Mul(graph, dict(name=op.id + '/mul_')).create_node(inputs=[tail, scale.create_node()])
    if has_bias:
        bias = Const(graph, dict(value=op.bias, shape=op.bias.shape))
        tail = Add(graph, dict(name=op.id + '/add_')).create_node(inputs=[tail, bias.create_node()])

    # Re-home every ImageScaler consumer onto the tail of the new chain
    for consumer in consumers:
        attrs = graph.get_edge_data(op.id, consumer.id)[0]
        graph.remove_edge(op.id, consumer.id)
        graph.add_edges_from([(tail.id, consumer.id, attrs)])

    graph.remove_edge(op.in_node().id, op.id)
def replace_pattern(self, graph: nx.MultiDiGraph, match: dict):
    """
    Fix a pack/unpack Reshape chain around a SoftMax for NHWC graphs.

    Only runs for NHWC layout and when ``self.is_reshape_bad`` flags the matched
    Reshapes. It then:
      1. inserts a Permute(order=[0, 2, 3, 1]) in front of ``reshape_split``;
      2. rewrites ``reshape_pack`` to a rank-3 shape
         ``[N, prod(dims 1..3), last dim]``, marks it ``nchw_layout``, and appends
         a Permute(order=[0, 2, 1]) after it;
      3. reorders Softmax -> Reshape(unpack) -> StridedSlice into
         Softmax -> StridedSlice -> Reshape(unpack), adjusting the StridedSlice
         slices/masks/output shape and dropping element 4 of the unpack ``dim``.
    """
    if graph.graph['layout'] != 'NHWC':
        return

    if self.is_reshape_bad(match['reshape_pack'], match['reshape_unpack'], match['strided_slice']):
        log.info("Reshape that pack/unpack several dimensions detected {}".format(match['reshape_pack'].id))
        node_split = match['reshape_split']

        # insert Permute before reshape
        data_node = Op._create_data_node(graph, node_split.name + "/Permute_before_data")
        permute_before = Permute(graph, dict(name=node_split.name + "/Permute_before",
                                             order=np.array([0, 2, 3, 1])))
        in_node = node_split.in_node(0)
        attrs = deepcopy(graph.get_edge_data(in_node.id, node_split.id)[0])
        graph.remove_edge(in_node.id, node_split.id)
        permute_before_node = permute_before.create_node_with_data([in_node], permute_before.attrs,
                                                                   data_nodes=[data_node])
        graph.add_edge(permute_before_node.id, node_split.id, **attrs)

        node = match['reshape_pack']
        node['nchw_layout'] = True
        # [N, product of input dims 1..3, last input dim]
        new_reshape_shape = np.concatenate((np.array([node.in_node(0).shape[0]]),
                                            np.array([np.prod(node.in_node(0).shape[[1, 2, 3]])]),
                                            np.array([node.in_node(0).shape[-1]])))

        node.dim = new_reshape_shape

        # insert Permute after reshape
        data_node = Op._create_data_node(graph, node.name + "/Permute_after_data", {'shape': node.dim})
        permute_after = Permute(graph, dict(name=node.name + "/Permute_after", order=np.array([0, 2, 1])))
        out_node = node.out_node(0)
        # The existing output data node now receives the permuted (last two dims swapped) shape
        out_node.shape = new_reshape_shape[np.array([0, 2, 1])]
        attrs = deepcopy(graph.get_edge_data(node.id, out_node.id)[0])
        graph.remove_edge(node.id, out_node.id)

        permute_after_node = permute_after.create_node_with_data([data_node], permute_after.attrs,
                                                                 data_nodes=[out_node])
        # The pack Reshape now writes into the new intermediate data node, reusing the old edge attributes
        graph.add_edge(node.id, data_node.id, **attrs)

        # update softmax shape
        node_softmax = match['softmax']
        node_softmax.out_node(0).shape = out_node.shape

        # revert strided slice and reshape
        node_ss = match['strided_slice']
        node_unpack = match['reshape_unpack']

        unpack_out = node_unpack.out_node(0).id
        ss_out = node_ss.out_node(0).id

        # gather edge attributes
        soft_reshape_attrs = deepcopy(graph.get_edge_data(node_softmax.out_node(0).id, node_unpack.id)[0])
        reshape_data_attrs = deepcopy(graph.get_edge_data(node_unpack.id, unpack_out)[0])
        reshape_ss_attrs = deepcopy(graph.get_edge_data(unpack_out, node_ss.id)[0])
        ss_out_attrs = deepcopy(graph.get_edge_data(node_ss.id, ss_out)[0])

        # remove all edges in Softmax->Reshape->StridedSlice chain
        graph.remove_edge(node_softmax.out_node(0).id, node_unpack.id)
        graph.remove_edge(node_unpack.id, unpack_out)
        graph.remove_edge(unpack_out, node_ss.id)
        graph.remove_edge(node_ss.id, ss_out)

        # add new edges to get chain Softmax->StridedSlice->Reshape
        # (the data node that was the Reshape output becomes the StridedSlice output and vice versa)
        graph.add_edge(node_softmax.out_node(0).id, node_ss.id, **soft_reshape_attrs)
        graph.add_edge(node_ss.id, unpack_out, **reshape_data_attrs)
        graph.add_edge(unpack_out, node_unpack.id, **reshape_ss_attrs)
        graph.add_edge(node_unpack.id, ss_out, **ss_out_attrs)

        # update output shape and parameters for StridedSlice
        node_ss.out_node(0).shape = np.zeros(3)
        node_ss.out_node(0).shape[0] = out_node.shape[0]
        node_ss.out_node(0).shape[1] = 1
        node_ss.out_node(0).shape[2] = out_node.shape[2]

        # Keep the first and last original slices and take the full range on the last axis
        old_slices = node_ss.slices.copy()
        node_ss.slices = []
        node_ss.slices.append(old_slices[0])
        node_ss.slices.append(old_slices[-1])
        node_ss.slices.append(slice(0, out_node.shape[2], 1))
        node_ss.shrink_axis_mask = [False, False, False]
        node_ss.new_axis_mask = [False, False, False]

        # update Reshape attribute
        node_unpack.dim = np.delete(node_unpack.dim, 4)
        # prevent permute for reshape because it gives wrong result
        node_unpack['nchw_layout'] = True
        node_unpack.out_node(0)['nchw_layout'] = True
def replace_pattern(graph: nx.MultiDiGraph, match: dict):
    """
    Convert a matched TF BlockLSTM sub-graph into an LSTMSequence node.

    Works in three stages (numbered #1..#3 in the enclosing class description,
    which is not part of this view):

    1. The concatenated cell states output of BlockLSTM must be consumed only by
       the ConcatV2 -> Reshape -> Gather chain that takes the last cell state.
       The IE TensorIterator primitive has no concatenated cell states output,
       so if this cannot be collapsed the BlockLSTM is not supported and an
       ``Error`` is raised.
    2. Weights/biases are repacked from TF layout (IO order, gates i, c, f, o)
       to OI order with gates f, i, c, o, ``forget_bias`` is added to the
       f-gate bias, the last cell state is attached to BlockLSTM output port 2
       and the node is relabelled ``LSTMSequence``.
    3. Optional: a multiplication of the concatenated hidden states by a
       constant 1 is removed and, when the only outside consumer is
       ``after_mul_op_to_the_rest_of_model`` and the Gather takes index
       ``time_len``, the last hidden state is attached to output port 1. If a
       precondition fails, the function returns early and keeps the changes
       already made.

    :param graph: graph containing the match; modified in place.
    :param match: pattern match mapping names ('BlockLSTM', 'gather_1', ...) to nodes.
    :raises Error: if the concatenated cell states output cannot be replaced.
    """
    def single_gather_index(index_value):
        """Return index_value[0] for a sized value of length 1, otherwise None.

        None (non-constant index) and 0-d arrays are rejected instead of making
        len() raise TypeError.
        """
        if index_value is None or np.ndim(index_value) == 0 or len(index_value) != 1:
            return None
        return index_value[0]

    time_len = match['concatenated_hidden_states'].shape[0]
    block_lstm = match['BlockLSTM']
    cell_states_error = (
        "BlockLSTM node {} has output which contains concatenated cell states over the whole "
        "time sequence. It is not replaceable by another output and is not supported "
        "originally".format(block_lstm.id))

    # ---- Stage 1: concatenated cell states may only feed this chain:
    #
    #   BlockLSTM
    #       || out 1 (concatenated cell states coming out of BlockLSTM)
    #       \/ in 1
    #   ConcatV2
    #       || (concatenation with initial state or another unused data)
    #       \/
    #   Reshape
    #       ||
    #       \/
    #   Gather (taking the last cell state from previous BlockLSTM, if Gather indexes == time_len)
    cell_chain_ids = {match[name].id for name in (
        'concat_1', 'concat_1_data', 'reshape_1', 'reshape_1_data', 'gather_1', 'gather_1_data')}
    for name in ('concatenated_cell_states_data', 'concat_1_data', 'reshape_1_data'):
        if any(consumer.id not in cell_chain_ids for consumer in match[name].out_nodes()):
            raise Error(cell_states_error)

    # The Gather must take exactly index time_len (the last cell state).
    last_cell_index = single_gather_index(match['gather_1'].in_node(1).value)
    if last_cell_index is None or last_cell_index != time_len:
        raise Error(cell_states_error)

    # ---- Stage 2: stages #1 and #2 passed, so the pattern can be translated to
    # LSTMSequence even without the optional optimization below.
    weights_node = block_lstm.in_node(1)
    biases_node = block_lstm.in_node(2)
    shift_const = block_lstm.forget_bias

    # TF stores weights in IO order
    input_size = block_lstm.in_node(0).shape[-1]
    hidden_size = block_lstm.in_node(3).shape[-1]
    weights = weights_node.value
    biases = biases_node.value
    assert weights.shape[0] == input_size + hidden_size, \
        "weights.shape={} input_size={} hidden_size={}".format(weights.shape, input_size, hidden_size)
    assert weights.shape[1] == biases.shape[0] == 4 * hidden_size, \
        "weights.shape={} biases.shape={} hidden_size={}".format(weights.shape, biases.shape, hidden_size)

    # Temporary per-gate shapes for easier manipulation.
    weights = weights.reshape([weights.shape[0], 4, hidden_size])
    biases = biases.reshape([4, hidden_size])

    # Reorder gates icfo --> fico for both weights and biases.
    gate_reorder = [2, 0, 1, 3]
    weights = np.take(weights, gate_reorder, axis=1)
    biases = np.take(biases, gate_reorder, axis=0)

    # After the reorder, row 0 is the f-gate, so forget_bias goes there. If this is
    # moved before the reordering, the addition must target a different row.
    biases[0] += shift_const

    # Back to 2-D / 1-D. TF stores weights in IO, but IE requires OI: transpose.
    weights = weights.reshape([weights.shape[0], -1]).transpose()
    biases = biases.flatten()

    weights_node.value = weights
    weights_node.shape = np.array(weights.shape, dtype=np.int64)
    biases_node.value = biases
    biases_node.shape = np.array(biases.shape, dtype=np.int64)

    # Feed the Gather's output (last cell state) directly from BlockLSTM port 2.
    attrs = dict(graph.get_edge_data(match['gather_1'].id, match['gather_1_data'].id)[0])
    attrs.update({'out': 2})
    graph.remove_edge(block_lstm.id, match['concatenated_cell_states_data'].id)
    graph.remove_edge(match['gather_1'].id, match['gather_1_data'].id)
    graph.add_edge(block_lstm.id, match['gather_1_data'].id, **attrs)

    block_lstm.op = 'LSTMSequence'
    block_lstm['sequence_dim'] = 0  # TF reference
    block_lstm['batch_dim'] = 1  # TF reference
    block_lstm['direction'] = 'forward'  # TF reference
    block_lstm['hidden_size'] = match['concatenated_hidden_states'].shape[-1]
    block_lstm['format'] = 'tf'

    # ---- Stage 3 (optional): remove a Mul by constant 1 on the hidden states.
    hidden_states = match['concatenated_hidden_states']
    mul = match['mul']
    other_mul_inputs = [n for n in mul.in_nodes().values() if n.id != hidden_states.id]
    if len(other_mul_inputs) != 1:
        return  # unexpected type of mul
    multiplier = other_mul_inputs[0]
    if not multiplier.has_valid('value') or not np.all(multiplier.value == 1):
        return  # unexpected type of mul

    # remove useless mul: connect BlockLSTM straight to the Mul output data node
    attrs = dict(graph.get_edge_data(block_lstm.id, hidden_states.id)[0])
    graph.remove_edge(block_lstm.id, hidden_states.id)
    graph.remove_edge(mul.id, match['mul_data'].id)
    graph.add_edge(block_lstm.id, match['mul_data'].id, **attrs)

    # Find true usages of the concatenated hidden states (not the last hidden state).
    hidden_chain_ids = {match[name].id for name in (
        'mul_data', 'concat_0', 'concat_0_data', 'reshape_0', 'reshape_0_data',
        'gather_0', 'gather_0_data')}
    outside_consumer_ids = [
        consumer.id
        for name in ('mul_data', 'concat_0_data', 'reshape_0_data')
        for consumer in match[name].out_nodes()
        if consumer.id not in hidden_chain_ids
    ]
    if outside_consumer_ids != [match['after_mul_op_to_the_rest_of_model'].id]:
        return  # not supported placement of pattern

    last_hidden_index = single_gather_index(match['gather_0'].in_node(1).value)
    if last_hidden_index is None or last_hidden_index != time_len:
        return  # we have to translate this type of BlockLSTM to LSTMSequence to TensorIterator as is

    # Feed the Gather's output (last hidden state) directly from BlockLSTM port 1.
    attrs = dict(graph.get_edge_data(match['gather_0'].id, match['gather_0_data'].id)[0])
    attrs.update({'out': 1})
    graph.remove_edge(match['mul_data'].id, match['concat_0'].id)
    graph.remove_edge(match['gather_0'].id, match['gather_0_data'].id)
    graph.add_edge(block_lstm.id, match['gather_0_data'].id, **attrs)