示例#1
0
def get_reverse_nodes_after_y_output(g, lstm_bw):
    """Find the reverse ops that follow the first output of a backward LSTM.

    Only two consumer chains of ``lstm_bw.output[0]`` are accepted:
    ``Squeeze -> Transpose -> reverse op(s)`` and ``Squeeze -> reverse op``.
    Exactly one Squeeze must consume the output, and exactly one node must
    consume that Squeeze.

    Args:
        g: graph object exposing ``find_output_consumers(output_name)``.
        lstm_bw: the candidate backward LSTM node.

    Returns:
        The list of reverse nodes when the pattern matches, otherwise ``None``.
        The list is empty when the Transpose has no consumers, so callers
        should test the result for truthiness rather than ``is None``.
    """
    bw_consumers = g.find_output_consumers(lstm_bw.output[0])

    # todo: figure out a better way to remove reverse op
    squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
    s_cnt = len(squeeze_nodes)
    if s_cnt != 1:
        log.debug("unexpected number of squeeze following LSTM 1st output:%s",
                  s_cnt)
        return None

    trans_nodes = g.find_output_consumers(squeeze_nodes[0].output[0])
    if len(trans_nodes) != 1:
        # Report the count that actually failed this check; the squeeze count
        # is always 1 here and would make the message useless.
        log.debug("unexpected number of transpose after LSTM 1st output:%s",
                  len(trans_nodes))
        return None

    follower = trans_nodes[0]
    if follower.type == "Transpose":
        # The reverse ops are expected right behind the Transpose.
        reverse_nodes = g.find_output_consumers(follower.output[0])
    elif is_reverse_op(follower):
        reverse_nodes = trans_nodes
    else:
        log.debug("not found reverse op, unexpected")
        return None

    if all(is_reverse_op(r_op) for r_op in reverse_nodes):
        return reverse_nodes

    log.debug("bw y output is used followed by reverse node")
    return None
示例#2
0
def rewrite_bidirectional_lstms(g, ops):
    """Pair forward and backward LSTM nodes that read the same source tensor.

    Each LSTM node in ``g`` is classified by its first input's producer
    (looking through one optional Transpose): if that producer is a reverse
    op the node is a backward candidate, otherwise it is a forward one.
    A backward candidate is only kept when
    ``get_reverse_nodes_after_y_output`` finds reverse ops after its output.

    Args:
        g: graph exposing ``get_nodes()``.
        ops: unused by this function.

    Returns:
        The result of ``process_bilstm(g, pairs)``, where every pair is
        ``([input_id, fw_node], [input_id, bw_node])``.
    """
    forward = {}
    backward = {}
    for node in g.get_nodes():
        if node.type != "LSTM":
            continue

        source_id = node.input[0]
        producer = node.inputs[0]
        if producer.type == "Transpose":
            # look through the Transpose to the tensor feeding it
            source_id, producer = producer.input[0], producer.inputs[0]

        if not is_reverse_op(producer):
            log.debug("find fw lstm %s", source_id)
            forward[source_id] = [source_id, node]
            continue

        # the real source is whatever feeds the reverse op
        source_id = producer.input[0]
        # make sure reverse lstm output will be reversed back
        if get_reverse_nodes_after_y_output(g, node):
            log.debug("find bw lstm %s", source_id)
            backward[source_id] = [source_id, node]

    shared_inputs = set(forward.keys()).intersection(backward.keys())
    pairs = [(forward[key], backward[key]) for key in shared_inputs]

    return process_bilstm(g, pairs)
示例#3
0
def rewrite_bidirectional_grus(g, ops):
    """Pair forward and backward GRU nodes that read the same source tensor.

    Each GRU node in ``g`` is classified by its first input's producer
    (looking through one optional Transpose): a reverse op marks the node as
    backward, anything else as forward.

    Args:
        g: graph exposing ``get_nodes()``.
        ops: unused by this function.

    Returns:
        The result of ``process_bigru(g, pairs)``. ``pairs`` is a list of
        tuples ``([input_id, fw gru node], [input_id, bw gru node])`` in
        which both entries share the same ``input_id``.
    """
    forward = {}
    backward = {}
    for node in g.get_nodes():
        if node.type != "GRU":
            continue

        source_id = node.input[0]
        producer = node.inputs[0]
        if producer.type == "Transpose":
            # look through the Transpose to the tensor feeding it
            source_id, producer = producer.input[0], producer.inputs[0]

        if is_reverse_op(producer):
            # the real source is whatever feeds the reverse op
            source_id = producer.input[0]
            log.debug("find bw gru %s", source_id)
            backward[source_id] = [source_id, node]
        else:
            log.debug("find fw gru %s", source_id)
            forward[source_id] = [source_id, node]

    # a forward and a backward GRU on the same input may form a bi-GRU
    shared_inputs = set(forward.keys()).intersection(backward.keys())
    pairs = [(forward[key], backward[key]) for key in shared_inputs]

    return process_bigru(g, pairs)
def rewrite_bidirectional_lstms(g, ops):
    """Pair forward and backward LSTM nodes that read the same source tensor.

    Each LSTM node in ``g`` is classified by its first input's producer
    (looking through one optional Transpose): a reverse op marks the node as
    a backward candidate, anything else as forward. A backward candidate is
    dropped when its first output has consumers but
    ``get_reverse_nodes_after_y_output`` finds no reverse ops behind it.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused by this function.

    Returns:
        The result of ``process_bilstm(g, pairs)``, where every pair is
        ``([input_id, fw_node], [input_id, bw_node])``.
    """
    forward = {}
    backward = {}
    for node in g.get_nodes():
        if node.type != "LSTM":
            continue

        source_id = node.input[0]
        producer = node.inputs[0]
        if producer.type == "Transpose":
            # look through the Transpose to the tensor feeding it
            source_id, producer = producer.input[0], producer.inputs[0]

        if not is_reverse_op(producer):
            logger.debug("find fw lstm %s", source_id)
            forward[source_id] = [source_id, node]
            continue

        # the real source is whatever feeds the reverse op
        source_id = producer.input[0]

        # A consumed output 0 that is never reversed back means this is not
        # a reversed lstm after all.
        y_consumers = g.find_output_consumers(node.output[0])
        if y_consumers and not get_reverse_nodes_after_y_output(g, node):
            continue

        logger.debug("find bw lstm %s", source_id)
        backward[source_id] = [source_id, node]

    shared_inputs = set(forward.keys()).intersection(backward.keys())
    pairs = [(forward[key], backward[key]) for key in shared_inputs]

    return process_bilstm(g, pairs)
示例#5
0
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm,
                                             lstm_output_index, all_nodes,
                                             to_remove):
    """Point consumers of the separate fw/bw LSTM outputs at slices of ``bi_lstm``.

    For output index 0 the reverse ops found behind the backward LSTM's
    ``Squeeze`` (optionally through one ``Transpose``) are bypassed and their
    names appended to ``to_remove``. Afterwards, for each direction that has
    consumers, a ``Slice`` node over ``bi_lstm.output[lstm_output_index]`` is
    created, appended to ``all_nodes``, and wired in place of the original
    output: forward takes ``[0, 1)`` and backward ``[1, 2)`` along axis 1
    (output 0) or axis 0 (outputs 1 and 2).

    Args:
        g: graph exposing ``find_output_consumers`` and ``replace_all_inputs``.
        lstm_fw: forward LSTM node.
        lstm_bw: backward LSTM node.
        bi_lstm: the fused bidirectional LSTM node.
        lstm_output_index: which LSTM output to handle (0, 1 or 2).
        all_nodes: node list; new Slice nodes are appended to it.
        to_remove: list collecting names of reverse ops to delete.

    Raises:
        ValueError: if ``lstm_output_index`` is not 0, 1 or 2, or if the
            consumer chain after output 0 of ``lstm_bw`` has an unexpected shape.
    """
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        # todo: figure out a better way to remove reverse op
        squeezes = [node for node in bw_consumers if node.type == "Squeeze"]
        if len(squeezes) > 1:
            raise ValueError(
                "unexpected number of squeeze following LSTM 1st output")
        if squeezes:
            after_squeeze = g.find_output_consumers(squeezes[0].output[0])
            if len(after_squeeze) != 1:
                raise ValueError(
                    "unexpected number of transpose after LSTM 1st output")

            head = after_squeeze[0]
            if head.type == "Transpose":
                reverse_ops = g.find_output_consumers(head.output[0])
            elif is_reverse_op(head):
                reverse_ops = after_squeeze
            else:
                raise ValueError("not found reverse op, unexpected")

            for rev in reverse_ops:
                log.debug("remove reverse op called %s", rev.name)
                # bypass the reverse op: its consumers read its input directly
                g.replace_all_inputs(all_nodes, rev.output[0], rev.input[0])
                to_remove.append(rev.name)
    elif lstm_output_index in (1, 2):
        axis = 0
    else:
        raise ValueError("LSTM only should has 3 outputs.")

    # forward consumers read slice [0, 1), backward consumers slice [1, 2)
    directions = ((fw_consumers, lstm_fw, 0), (bw_consumers, lstm_bw, 1))
    for consumers, original, start in directions:
        if not consumers:
            continue
        attr = {"axes": [axis], "starts": [start], "ends": [start + 1]}
        slice_node = make_onnx_node(g, "Slice",
                                    [bi_lstm.output[lstm_output_index]],
                                    attr)
        all_nodes.append(slice_node)
        g.replace_all_inputs(consumers, original.output[lstm_output_index],
                             slice_node.output[0])
示例#6
0
def rewrite_bidirectional_grus(g, ops):
    """Pair forward and backward GRU nodes that read the same source tensor.

    Each GRU node in ``g`` is classified by its first input's producer
    (looking through one optional Transpose): a reverse op marks the node as
    a backward candidate, anything else as forward. A backward candidate is
    dropped when its first output has consumers but
    ``get_reverse_nodes_after_y_output`` finds no reverse ops behind it.

    Args:
        g: graph exposing ``get_nodes()`` and ``find_output_consumers()``.
        ops: unused by this function.

    Returns:
        The result of ``process_bigru(g, pairs)``. ``pairs`` is a list of
        tuples ``([input_id, fw gru node], [input_id, bw gru node])`` in
        which both entries share the same ``input_id``.
    """
    forward = {}
    backward = {}
    for node in g.get_nodes():
        if node.type != "GRU":
            continue

        source_id = node.input[0]
        producer = node.inputs[0]
        if producer.type == "Transpose":
            # look through the Transpose to the tensor feeding it
            source_id, producer = producer.input[0], producer.inputs[0]

        if not is_reverse_op(producer):
            logger.debug("find fw gru %s", source_id)
            forward[source_id] = [source_id, node]
            continue

        # the real source is whatever feeds the reverse op
        source_id = producer.input[0]

        # A consumed output 0 that is never reversed back means this is not
        # a reversed gru after all.
        y_consumers = g.find_output_consumers(node.output[0])
        if y_consumers and not get_reverse_nodes_after_y_output(g, node):
            continue

        logger.debug("find bw gru %s", source_id)
        backward[source_id] = [source_id, node]

    # a forward and a backward GRU on the same input may form a bi-GRU
    shared_inputs = set(forward.keys()).intersection(backward.keys())
    pairs = [(forward[key], backward[key]) for key in shared_inputs]

    return process_bigru(g, pairs)