def get_reverse_nodes_after_y_output(g, lstm_bw):
    """Find the reverse ops that follow a backward RNN node's first output.

    Starting from ``lstm_bw.output[0]`` the accepted consumer chains are
    ``Squeeze -> Transpose -> <candidates>`` and ``Squeeze -> <reverse op>``.

    Args:
        g: graph object exposing ``find_output_consumers(output_name)``.
        lstm_bw: candidate backward node; only ``output[0]`` is inspected.

    Returns:
        The list of reverse nodes when the chain matches and every candidate
        satisfies ``is_reverse_op``; ``None`` otherwise. The list is empty
        when the Transpose has no consumers, so callers should test the
        result for truthiness rather than ``is not None``.
    """
    bw_consumers = g.find_output_consumers(lstm_bw.output[0])

    # todo: figure out a better way to remove reverse op
    squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
    s_cnt = len(squeeze_nodes)
    if s_cnt != 1:
        log.debug("unexpected number of squeeze following LSTM 1st output:%s", s_cnt)
        return None

    trans_nodes = g.find_output_consumers(squeeze_nodes[0].output[0])
    if len(trans_nodes) != 1:
        # Report the count that actually triggered this branch; s_cnt is
        # always 1 here and would make the message useless.
        log.debug("unexpected number of transpose after LSTM 1st output:%s", len(trans_nodes))
        return None

    follower = trans_nodes[0]
    if follower.type == "Transpose":
        # Candidates are whatever consumes the Transpose; verified below.
        reverse_nodes = g.find_output_consumers(follower.output[0])
    elif is_reverse_op(follower):
        reverse_nodes = trans_nodes
    else:
        log.debug("not found reverse op, unexpected")
        return None

    if all(is_reverse_op(r_op) for r_op in reverse_nodes):
        return reverse_nodes

    # At least one consumer is not a reverse op, so the output is also used
    # in its un-reversed form.
    log.debug("bw y output is used followed by reverse node")
    return None
def rewrite_bidirectional_lstms(g, ops):
    """Pair forward and backward LSTM nodes that share an input and fuse them.

    An LSTM is treated as backward when its first input (optionally through a
    ``Transpose``) is produced by a node accepted by ``is_reverse_op``; the
    pairing key is the tensor feeding that Transpose / reverse op.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused here; kept for the rewriter call signature.

    Returns:
        Whatever ``process_bilstm(g, bi_lstms)`` returns, where ``bi_lstms`` is
        a list of ``([input_id, fw_node], [input_id, bw_node])`` tuples.
    """
    fw_lstm = {}
    bw_lstm = {}
    for n in g.get_nodes():
        if n.type != "LSTM":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_lstm = False
        if temp.type == "Transpose":
            # Look through the Transpose to the tensor that feeds it.
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_lstm = True

        if is_backward_lstm:
            # If output 0 is consumed it must be reversed back afterwards,
            # otherwise this is not a reversed lstm. An LSTM whose output 0 is
            # unused (only the later outputs are consumed) still qualifies;
            # previously such nodes were dropped and never paired.
            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
                continue

            log.debug("find bw lstm %s", input_id)
            bw_lstm[input_id] = [input_id, n]
        else:
            log.debug("find fw lstm %s", input_id)
            fw_lstm[input_id] = [input_id, n]

    # Walk fw_lstm in insertion (graph) order instead of a set intersection:
    # set iteration order of strings changes with hash randomisation, which
    # made the fusion order differ from run to run.
    bi_lstms = [(fw_lstm[input_id], bw_lstm[input_id]) for input_id in fw_lstm if input_id in bw_lstm]

    return process_bilstm(g, bi_lstms)
def rewrite_bidirectional_grus(g, ops):
    """Pair forward and backward GRU nodes that share an input and fuse them.

    A GRU is treated as backward when its first input (optionally through a
    ``Transpose``) is produced by a node accepted by ``is_reverse_op``; the
    pairing key is the tensor feeding that Transpose / reverse op.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused here; kept for the rewriter call signature.

    Returns:
        Whatever ``process_bigru(g, bi_grus)`` returns. ``bi_grus`` is a list
        of tuples ``([fw input_id, fw onnx gru node], [bw input_id, bw onnx
        gru node])`` in which the fw input_id equals the bw input_id.
    """
    fw_gru = {}
    bw_gru = {}
    for n in g.get_nodes():
        if n.type != "GRU":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_gru = False
        if temp.type == "Transpose":
            # Look through the Transpose to the tensor that feeds it.
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_gru = True

        if is_backward_gru:
            # If output 0 is consumed and nothing reverses it back, this is
            # not a reversed gru: a reversed input alone is not enough to
            # accept the node as the backward half of a pair.
            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
                continue

            log.debug("find bw gru %s", input_id)
            bw_gru[input_id] = [input_id, n]
        else:
            log.debug("find fw gru %s", input_id)
            fw_gru[input_id] = [input_id, n]

    # when fw_gru has same input as bw_gru, then it may be a bi gru.
    # Walk fw_gru in insertion (graph) order instead of a set intersection:
    # set iteration order of strings changes with hash randomisation, which
    # made the fusion order differ from run to run.
    bi_grus = [(fw_gru[input_id], bw_gru[input_id]) for input_id in fw_gru if input_id in bw_gru]

    return process_bigru(g, bi_grus)
def rewrite_bidirectional_lstms(g, ops):
    """Pair forward and backward LSTM nodes that share an input and fuse them.

    An LSTM is treated as backward when its first input (optionally through a
    ``Transpose``) is produced by a node accepted by ``is_reverse_op``, and
    its output 0 is either unused or reversed back afterwards. The pairing key
    is the tensor feeding that Transpose / reverse op.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused here; kept for the rewriter call signature.

    Returns:
        Whatever ``process_bilstm(g, bi_lstms)`` returns, where ``bi_lstms`` is
        a list of ``([input_id, fw_node], [input_id, bw_node])`` tuples.
    """
    fw_lstm = {}
    bw_lstm = {}
    for n in g.get_nodes():
        if n.type != "LSTM":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_lstm = False
        if temp.type == "Transpose":
            # Look through the Transpose to the tensor that feeds it.
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_lstm = True

        if is_backward_lstm:
            # if output 0 is consumed, and there is no reverse after the lstm output.
            # it's not reversed lstm
            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
                continue

            logger.debug("find bw lstm %s", input_id)
            bw_lstm[input_id] = [input_id, n]
        else:
            logger.debug("find fw lstm %s", input_id)
            fw_lstm[input_id] = [input_id, n]

    # Walk fw_lstm in insertion (graph) order instead of a set intersection:
    # set iteration order of strings changes with hash randomisation, which
    # made the fusion order differ from run to run.
    bi_lstms = [(fw_lstm[input_id], bw_lstm[input_id]) for input_id in fw_lstm if input_id in bw_lstm]

    return process_bilstm(g, bi_lstms)
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm, lstm_output_index, all_nodes, to_remove):
    """Re-route consumers of the original fw/bw LSTM outputs to the fused node.

    For the given output index, each original LSTM that has consumers gets a
    ``Slice`` of the matching ``bi_lstm`` output (index 0 for forward, index 1
    for backward along the chosen axis), and its consumers are rewired to it.
    For output 0 the reverse ops found after the backward LSTM's
    Squeeze (-> Transpose) chain are bypassed and queued for removal first.

    Args:
        g: graph exposing ``find_output_consumers`` and ``replace_all_inputs``.
        lstm_fw: original forward LSTM node.
        lstm_bw: original backward LSTM node.
        bi_lstm: fused node whose outputs are sliced.
        lstm_output_index: which LSTM output to handle; must be 0, 1 or 2.
        all_nodes: node list; new Slice nodes are appended to it and it is the
            scope used when bypassing reverse ops.
        to_remove: list that receives the names of bypassed reverse ops.

    Raises:
        ValueError: for an output index outside 0..2, or when the consumers of
            the backward LSTM's output 0 do not form the expected chain.
    """
    out_idx = lstm_output_index
    fw_consumers = g.find_output_consumers(lstm_fw.output[out_idx])
    bw_consumers = g.find_output_consumers(lstm_bw.output[out_idx])

    if out_idx not in (0, 1, 2):
        raise ValueError("LSTM only should has 3 outputs.")

    if out_idx == 0:
        # NOTE(review): axis 1 for output 0 and axis 0 for outputs 1/2 is
        # presumably the num_directions axis of the ONNX LSTM outputs; confirm.
        axis = 1
        # remove reverse op for lstm_bw
        # todo: figure out a better way to remove reverse op
        squeezes = [node for node in bw_consumers if node.type == "Squeeze"]
        if len(squeezes) > 1:
            raise ValueError("unexpected number of squeeze following LSTM 1st output")

        if squeezes:
            after_squeeze = g.find_output_consumers(squeezes[0].output[0])
            if len(after_squeeze) != 1:
                raise ValueError("unexpected number of transpose after LSTM 1st output")

            follower = after_squeeze[0]
            if follower.type == "Transpose":
                reverse_nodes = g.find_output_consumers(follower.output[0])
            elif is_reverse_op(follower):
                reverse_nodes = after_squeeze
            else:
                raise ValueError("not found reverse op, unexpected")

            # Bypass each reverse op: its consumers read its input directly.
            for rev in reverse_nodes:
                log.debug("remove reverse op called %s", rev.name)
                g.replace_all_inputs(all_nodes, rev.output[0], rev.input[0])
                to_remove.append(rev.name)
    else:
        axis = 0

    # Forward direction takes slice [0, 1), backward takes [1, 2).
    directions = ((fw_consumers, lstm_fw, 0), (bw_consumers, lstm_bw, 1))
    for consumers, original, start in directions:
        if not consumers:
            continue
        attr = {"axes": [axis], "starts": [start], "ends": [start + 1]}
        slice_node = make_onnx_node(g, "Slice", [bi_lstm.output[out_idx]], attr)
        all_nodes.append(slice_node)
        g.replace_all_inputs(consumers, original.output[out_idx], slice_node.output[0])
def rewrite_bidirectional_grus(g, ops):
    """Pair forward and backward GRU nodes that share an input and fuse them.

    A GRU is treated as backward when its first input (optionally through a
    ``Transpose``) is produced by a node accepted by ``is_reverse_op``, and
    its output 0 is either unused or reversed back afterwards. The pairing key
    is the tensor feeding that Transpose / reverse op.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused here; kept for the rewriter call signature.

    Returns:
        Whatever ``process_bigru(g, bi_grus)`` returns. ``bi_grus`` is a list
        of tuples ``([fw input_id, fw onnx gru node], [bw input_id, bw onnx
        gru node])`` in which the fw input_id equals the bw input_id.
    """
    fw_gru = {}
    bw_gru = {}
    for n in g.get_nodes():
        if n.type != "GRU":
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_gru = False
        if temp.type == "Transpose":
            # Look through the Transpose to the tensor that feeds it.
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_gru = True

        if is_backward_gru:
            # if output 0 is consumed, and there is no reverse after the gru output.
            # it's not reversed gru
            if g.find_output_consumers(n.output[0]) and not get_reverse_nodes_after_y_output(g, n):
                continue

            logger.debug("find bw gru %s", input_id)
            bw_gru[input_id] = [input_id, n]
        else:
            logger.debug("find fw gru %s", input_id)
            fw_gru[input_id] = [input_id, n]

    # when fw_gru has same input as bw_gru, then it may be a bi gru.
    # Walk fw_gru in insertion (graph) order instead of a set intersection:
    # set iteration order of strings changes with hash randomisation, which
    # made the fusion order differ from run to run.
    bi_grus = [(fw_gru[input_id], bw_gru[input_id]) for input_id in fw_gru if input_id in bw_gru]

    return process_bigru(g, bi_grus)