def get_reverse_nodes_after_y_output(g, rnn_bw):
    """Return the reverse ops that follow the first output of a backward RNN.

    The only consumer chain recognised after ``rnn_bw.output[0]`` is::

        RNN -> Squeeze -> [Transpose ->] reverse op(s)

    Args:
        g: graph object; only ``find_output_consumers(output_id)`` is used.
        rnn_bw: candidate backward RNN node; only ``output[0]`` is read.

    Returns:
        The consumers found at the end of the chain when every one of them
        satisfies ``utils.is_tf_reverse_op``; otherwise an empty list. Note
        that a Transpose without any consumer also yields an empty result.
    """
    bw_consumers = g.find_output_consumers(rnn_bw.output[0])

    # todo: figure out a better way to remove reverse op
    squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
    s_cnt = len(squeeze_nodes)
    if s_cnt != 1:
        logger.debug("unexpected number of squeeze following RNN 1st output:%s", s_cnt)
        return []

    trans_nodes = g.find_output_consumers(squeeze_nodes[0].output[0])
    if len(trans_nodes) != 1:
        # Log the number of Squeeze consumers; the Squeeze count is always 1
        # at this point and would carry no diagnostic information.
        logger.debug("unexpected number of transpose after RNN 1st output:%s",
                     len(trans_nodes))
        return []

    squeeze_consumer = trans_nodes[0]
    if squeeze_consumer.type == "Transpose":
        # The reverse ops are expected right behind the Transpose.
        reverse_nodes = g.find_output_consumers(squeeze_consumer.output[0])
    elif utils.is_tf_reverse_op(squeeze_consumer):
        # No Transpose in between: the Squeeze feeds the reverse op directly.
        reverse_nodes = trans_nodes
    else:
        logger.debug("not found reverse op, unexpected")
        return []

    if all(utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes):
        return reverse_nodes

    # Reached when at least one candidate is not a reverse op.
    logger.debug("bw y output is used followed by reverse node")
    return []
def find_bidirectional_rnns(g, ops, rnn_type):
    """
    Find possible bidirectional rnns, return: list of tuple,
    Format of tuple is (fw onnx rnn node, bw onnx rnn node).

    Nodes whose type equals ``onnx_rnn_type_mapping[rnn_type]`` are grouped
    by the id of the tensor feeding them. A Transpose producing the rnn's
    first input is looked through; if the producer reached that way is a
    reverse op, the rnn is treated as a backward candidate and the reverse
    op's own input becomes the grouping key. Forward and backward candidates
    that share a key are paired when ``belong_to_birnn`` accepts them, and
    each node is used in at most one pair.

    Args:
        g: graph object; ``get_nodes`` and ``find_output_consumers`` are used.
        ops: not used by this function.
        rnn_type: key into ``onnx_rnn_type_mapping``.

    Pairs are reported in the order in which the shared inputs were first
    recorded for a forward rnn (relies on insertion-ordered dicts,
    Python 3.7+), rather than in hash-dependent set order.
    """
    fw_rnns = defaultdict(list)
    bw_rnns = defaultdict(list)
    for n in g.get_nodes():
        if n.type != onnx_rnn_type_mapping[rnn_type]:
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_bw = False
        if temp.type == "Transpose":
            # look through the Transpose to the tensor that feeds it
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if utils.is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_bw = True

        if is_bw:
            # if output 0 is consumed and there is no reverse after the 1st output.
            # it's not backward rnn.
            if g.find_output_consumers(n.output[0]) and \
                    not get_reverse_nodes_after_y_output(g, n):
                logger.warning("rnn %s following Reverse op isn't the part of bi-rnn.", n.name)
                continue

            logger.debug("find bw rnn %s", input_id)
            bw_rnns[input_id].append(n)
        else:
            logger.debug("find fw rnn %s", input_id)
            fw_rnns[input_id].append(n)

    # fw_rnn and bw_rnn must share the same input.
    # Walk fw_rnns in insertion order instead of iterating a set so the
    # result order does not depend on string hash randomisation. The ``in``
    # test does not create entries in the defaultdict.
    birnn_input = [inp for inp in fw_rnns if inp in bw_rnns]
    bi_rnns = []
    matched_rnn = []
    for inp in birnn_input:
        fw_rnn = fw_rnns[inp]
        bw_rnn = bw_rnns[inp]
        # it's possible several bi-rnns share the same input
        for fw_n in fw_rnn:
            for bw_n in bw_rnn:
                if belong_to_birnn(g, fw_n, bw_n, rnn_type) and \
                        fw_n not in matched_rnn and bw_n not in matched_rnn:
                    logger.debug("found birnn comprising %s and %s", fw_n.name, bw_n.name)
                    bi_rnns.append((fw_n, bw_n))
                    matched_rnn.extend([fw_n, bw_n])
    return bi_rnns
def rewrite_bidirectional_grus(g, ops):
    """
    Collect forward/backward GRU pairs that share an input and hand them to
    process_bigru.

    The list passed to process_bigru is a list of tuple, format of tuple is
    ((fw input_id, fw onnx gru node), (bw input_id, bw onnx gru node)),
    and fw input_id equals to bw input_id.

    Args:
        g: graph object; ``get_nodes`` and ``find_output_consumers`` are used.
        ops: not used by this function.

    Returns:
        Whatever ``process_bigru(g, bi_grus)`` returns.

    NOTE: candidates are stored per input id, so when several forward (or
    several backward) GRUs share the same input only the last one visited
    is kept.
    """
    fw_gru = {}
    bw_gru = {}
    for n in g.get_nodes():
        if n.type != "GRU":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_gru = False
        if temp.type == "Transpose":
            # look through the Transpose to the tensor that feeds it
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_gru = True

        if is_backward_gru:
            # if output 0 is consumed, and there is no reverse after the gru output.
            # it's not reversed gru
            if g.find_output_consumers(n.output[0]) and \
                    not get_reverse_nodes_after_y_output(g, n):
                continue
            logger.debug("find bw gru %s", input_id)
            bw_gru[input_id] = [input_id, n]
        else:
            logger.debug("find fw gru %s", input_id)
            fw_gru[input_id] = [input_id, n]

    # when fw_gru has same input as bw_gru, then it may be a bi gru.
    # Iterate fw_gru in insertion order (Python 3.7+ dicts) instead of a set
    # so the pairs reach process_bigru in an order that does not depend on
    # string hash randomisation.
    bigru_input = [input_id for input_id in fw_gru if input_id in bw_gru]
    bi_grus = [(fw_gru[input_id], bw_gru[input_id]) for input_id in bigru_input]

    return process_bigru(g, bi_grus)
def rewrite_bidirectional_lstms(g, ops):
    """
    Collect forward/backward LSTM pairs that share an input and hand them to
    process_bilstm.

    The list passed to process_bilstm holds tuples of the form
    ([fw input_id, fw lstm node], [bw input_id, bw lstm node]) where both
    input ids are equal.

    Args:
        g: graph object; ``get_nodes`` and ``find_output_consumers`` are used.
        ops: not used by this function.

    Returns:
        Whatever ``process_bilstm(g, bi_lstms)`` returns.

    NOTE: candidates are stored per input id, so when several forward (or
    several backward) LSTMs share the same input only the last one visited
    is kept.
    """
    fw_lstm = {}
    bw_lstm = {}
    for n in g.get_nodes():
        if n.type != "LSTM":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_lstm = False
        if temp.type == "Transpose":
            # look through the Transpose to the tensor that feeds it
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_lstm = True

        if is_backward_lstm:
            # if output 0 is consumed, and there is no reverse after the lstm output.
            # it's not reversed lstm
            if g.find_output_consumers(n.output[0]) and \
                    not get_reverse_nodes_after_y_output(g, n):
                continue
            logger.debug("find bw lstm %s", input_id)
            bw_lstm[input_id] = [input_id, n]
        else:
            logger.debug("find fw lstm %s", input_id)
            fw_lstm[input_id] = [input_id, n]

    # A shared input between a forward and a backward lstm marks a candidate
    # bi-lstm. Iterate fw_lstm in insertion order (Python 3.7+ dicts) instead
    # of a set so the pairs reach process_bilstm in an order that does not
    # depend on string hash randomisation.
    bilstm_input = [input_id for input_id in fw_lstm if input_id in bw_lstm]
    bi_lstms = [(fw_lstm[input_id], bw_lstm[input_id]) for input_id in bilstm_input]

    return process_bilstm(g, bi_lstms)