コード例 #1
0
def get_reverse_nodes_after_y_output(g, rnn_bw):
    """Find the reverse ops applied to the Y output of a suspected backward RNN.

    The pattern matched, starting from ``rnn_bw.output[0]``, is::

        RNN Y -> Squeeze -> [Transpose ->] reverse op(s)

    Args:
        g: graph object; only ``g.find_output_consumers(name)`` is used.
        rnn_bw: RNN node whose first output is inspected.

    Returns:
        The list of reverse nodes following the Squeeze (or the Transpose
        after it) when every such consumer is a reverse op. Otherwise an
        empty list, which is also returned whenever the pattern does not
        match at any step.
    """
    bw_consumers = g.find_output_consumers(rnn_bw.output[0])

    # todo: figure out a better way to remove reverse op
    squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
    s_cnt = len(squeeze_nodes)
    if s_cnt != 1:
        logger.debug("unexpected number of squeeze following RNN 1st output:%s",
                     s_cnt)
        return []

    trans_nodes = g.find_output_consumers(squeeze_nodes[0].output[0])
    t_cnt = len(trans_nodes)
    if t_cnt != 1:
        # Report how many nodes consume the Squeeze output. s_cnt is always 1
        # on this path, so logging it would not explain the mismatch.
        logger.debug("unexpected number of transpose after RNN 1st output:%s",
                     t_cnt)
        return []

    follower = trans_nodes[0]
    if follower.type == "Transpose":
        # The reverse ops sit behind the optional Transpose.
        reverse_nodes = g.find_output_consumers(follower.output[0])
    elif utils.is_tf_reverse_op(follower):
        # No Transpose: the Squeeze feeds the reverse op directly.
        reverse_nodes = trans_nodes
    else:
        logger.debug("not found reverse op, unexpected")
        return []

    if all(utils.is_tf_reverse_op(r_op) for r_op in reverse_nodes):
        return reverse_nodes

    logger.debug("bw y output is used followed by reverse node")
    return []
コード例 #2
0
def find_bidirectional_rnns(g, ops, rnn_type):
    """
    Find possible bidirectional rnns, return: list of tuple,
    Format of tuple is (fw onnx rnn node, bw onnx rnn node).

    Args:
        g: graph to scan (uses ``get_nodes`` and ``find_output_consumers``).
        ops: unused here; kept for interface compatibility.
        rnn_type: key into ``onnx_rnn_type_mapping`` selecting which node
            type to look for.

    Classification:
      * An RNN is backward when its input, optionally behind a Transpose, is
        produced by a TF reverse op. Otherwise it is forward.
      * A backward candidate is skipped (with a warning) when its Y output
        is consumed without reverse ops after it.

    Pairing:
      * Forward and backward candidates that read the same input are paired
        when ``belong_to_birnn`` accepts them.
      * Each node ends up in at most one pair.
      * Pairs are returned in a reproducible order: the order in which their
        input was first seen among forward RNNs, not set-iteration order.
    """
    fw_rnns = defaultdict(list)
    bw_rnns = defaultdict(list)
    for n in g.get_nodes():
        if n.type != onnx_rnn_type_mapping[rnn_type]:
            continue

        input_id = n.input[0]
        temp = n.inputs[0]
        is_bw = False
        if temp.type == "Transpose":
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if utils.is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_bw = True

        if is_bw:
            # if output 0 is consumed and there is no reverse after the 1st output.
            # it's not backward rnn.
            if g.find_output_consumers(
                    n.output[0]) and not get_reverse_nodes_after_y_output(
                        g, n):
                logger.warning(
                    "rnn %s following Reverse op isn't the part of bi-rnn.",
                    n.name)
                continue

            logger.debug("find bw rnn %s", input_id)
            bw_rnns[input_id].append(n)
        else:
            logger.debug("find fw rnn %s", input_id)
            fw_rnns[input_id].append(n)

    # fw_rnn and bw_rnn must share the same input. Walk fw_rnns in insertion
    # order instead of intersecting sets, so the result order does not depend
    # on string hash randomization. ``in`` does not create defaultdict keys.
    birnn_input = [inp for inp in fw_rnns if inp in bw_rnns]
    bi_rnns = []
    matched_rnn = []
    for inp in birnn_input:
        # it's possible several bi-rnns share the same input
        for fw_n in fw_rnns[inp]:
            for bw_n in bw_rnns[inp]:
                # Skip already-paired nodes before the belong_to_birnn check.
                if fw_n in matched_rnn or bw_n in matched_rnn:
                    continue
                if belong_to_birnn(g, fw_n, bw_n, rnn_type):
                    logger.debug("found birnn comprising %s and %s", fw_n.name,
                                 bw_n.name)
                    bi_rnns.append((fw_n, bw_n))
                    matched_rnn.extend([fw_n, bw_n])
                    # fw_n is now paired; no later bw candidate could match it.
                    break
    return bi_rnns
コード例 #3
0
def rewrite_bidirectional_grus(g, ops):
    """Pair forward/backward GRU nodes that read the same input and rewrite them.

    Classification:
      * A GRU is backward when its input, optionally behind a Transpose, is
        produced by a TF reverse op. Otherwise it is forward.
      * A backward candidate whose Y output is consumed without reverse ops
        after it is ignored.

    Args:
        g: graph to scan.
        ops: unused here; kept for interface compatibility.

    Returns:
        Whatever ``process_bigru(g, bi_grus)`` returns. ``bi_grus`` is a list
        of tuples ``([input_id, fw gru node], [input_id, bw gru node])``;
        both input_ids are equal.
    """
    fw_gru = {}
    bw_gru = {}
    for n in g.get_nodes():
        if n.type != "GRU":
            continue
        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_gru = False
        if temp.type == "Transpose":
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_gru = True

        if is_backward_gru:
            # if output 0 is consumed, and there is no reverse after the gru output.
            # it's not reversed gru
            if g.find_output_consumers(
                    n.output[0]) and not get_reverse_nodes_after_y_output(
                        g, n):
                continue
            logger.debug("find bw gru %s", input_id)
            # NOTE: only one GRU per input is kept; a later one overwrites.
            bw_gru[input_id] = [input_id, n]
        else:
            logger.debug("find fw gru %s", input_id)
            fw_gru[input_id] = [input_id, n]

    # when fw_gru has same input as bw_gru, then it may be a bi gru.
    # Iterate fw_gru in insertion order instead of a set, so the pairing
    # (and rewrite) order is reproducible across runs.
    bi_grus = [(fw_gru[input_id], bw_gru[input_id])
               for input_id in fw_gru if input_id in bw_gru]

    return process_bigru(g, bi_grus)
コード例 #4
0
def rewrite_bidirectional_lstms(g, ops):
    """Pair forward/backward LSTM nodes that read the same input and rewrite them.

    Classification:
      * An LSTM is backward when its input, optionally behind a Transpose, is
        produced by a TF reverse op. Otherwise it is forward.
      * A backward candidate whose Y output is consumed without reverse ops
        after it is ignored.

    Args:
        g: graph to scan.
        ops: unused here; kept for interface compatibility.

    Returns:
        Whatever ``process_bilstm(g, bi_lstms)`` returns. ``bi_lstms`` is a
        list of tuples ``([input_id, fw lstm node], [input_id, bw lstm node])``;
        both input_ids are equal.
    """
    fw_lstm = {}
    bw_lstm = {}
    for n in g.get_nodes():
        if n.type != "LSTM":
            continue
        input_id = n.input[0]
        temp = n.inputs[0]
        is_backward_lstm = False
        if temp.type == "Transpose":
            input_id = temp.input[0]
            temp = temp.inputs[0]

        if is_tf_reverse_op(temp):
            input_id = temp.input[0]
            is_backward_lstm = True

        if is_backward_lstm:
            # if output 0 is consumed, and there is no reverse after the lstm output.
            # it's not reversed lstm
            if g.find_output_consumers(
                    n.output[0]) and not get_reverse_nodes_after_y_output(
                        g, n):
                continue

            logger.debug("find bw lstm %s", input_id)
            # NOTE: only one LSTM per input is kept; a later one overwrites.
            bw_lstm[input_id] = [input_id, n]
        else:
            logger.debug("find fw lstm %s", input_id)
            fw_lstm[input_id] = [input_id, n]

    # Iterate fw_lstm in insertion order instead of intersecting sets, so the
    # pairing (and rewrite) order is reproducible across runs.
    bi_lstms = [(fw_lstm[input_id], bw_lstm[input_id])
                for input_id in fw_lstm if input_id in bw_lstm]

    return process_bilstm(g, bi_lstms)