def _process_non_tuple_ch_init_nodes(self, rnn_props):
        input_id = rnn_props.var_initializers["ct_ht"]
        hidden_size = rnn_props.hidden_size

        # todo: remove this once the Fill op is supported
        fill_ch_init_node = self._workaround_fill_ch_init_node(
            input_id, rnn_props)
        if fill_ch_init_node:
            return fill_ch_init_node.output[0], fill_ch_init_node.output[0]

        attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
        slice_node1 = make_onnx_node(self.g, "Slice", [input_id], attr)
        unsqueeze_node_1 = make_onnx_node(self.g,
                                          "Unsqueeze", [slice_node1.output[0]],
                                          attr={"axes": [0]})

        attr = {
            "axes": [1],
            "starts": [hidden_size],
            "ends": [hidden_size * 2]
        }
        slice_node2 = make_onnx_node(self.g, "Slice", [input_id], attr)
        unsqueeze_node_2 = make_onnx_node(self.g,
                                          "Unsqueeze", [slice_node2.output[0]],
                                          attr={"axes": [0]})

        self.all_nodes.extend(
            [slice_node1, slice_node2, unsqueeze_node_1, unsqueeze_node_2])
        return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
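A minimal numpy sketch (hypothetical sizes and values) of what the two Slice + Unsqueeze pairs compute, assuming ct_ht is TF's non-tuple state concat([c, h], axis=1), as the name suggests:

import numpy as np

# hypothetical sizes, for illustration only
batch_size, hidden_size = 2, 3
ct_ht = np.zeros((batch_size, 2 * hidden_size), dtype=np.float32)

# Slice(axes=[1], starts=[0], ends=[hidden_size]) keeps the first half,
# Slice(axes=[1], starts=[hidden_size], ends=[2*hidden_size]) keeps the second;
# Unsqueeze(axes=[0]) then adds the num_directions dim ONNX expects.
c_part = np.expand_dims(ct_ht[:, :hidden_size], axis=0)   # (1, batch, hidden)
h_part = np.expand_dims(ct_ht[:, hidden_size:], axis=0)   # (1, batch, hidden)
assert c_part.shape == h_part.shape == (1, batch_size, hidden_size)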
Example #2
    def _workaround_fill_ch_init_node(self, initializer_input_id, rnn_props):
        node = self.g.get_node_by_name(initializer_input_id)
        if node.type != "Fill":
            return None

        fill_val = node.inputs[1].get_tensor_value()[0]
        fill_val_dtype = utils.ONNX_TO_NUMPY_DTYPE[node.inputs[1].dtype]

        # use float32 to match batch_size_node's dtype, since Concat's input data types must be consistent.
        num_direction_node = self.g.make_const(utils.make_name("Const"),
                                               np.array([1], dtype=np.float32))
        h_node = self.g.make_const(
            utils.make_name("Const"),
            np.array([rnn_props.hidden_size], dtype=np.float32))
        b_node = rnn_props.batch_size_node
        # Concat in OPSET7 does not support int64.
        tile_shape = make_onnx_node(
            self.g,
            "Concat",
            [num_direction_node.output[0], b_node.output[0], h_node.output[0]],
            attr={"axis": 0})

        # Tile's repeats must be INT64
        attr = {"to": onnx_pb.TensorProto.INT64}
        tile_shape_int64 = make_onnx_node(self.g, 'Cast',
                                          [tile_shape.output[0]], attr)

        const_node = self.g.make_const(
            utils.make_name("Const"),
            np.array([[[fill_val]]], dtype=fill_val_dtype))
        tile_node = make_onnx_node(
            self.g, 'Tile', [const_node.output[0], tile_shape_int64.output[0]])
        self.all_nodes.extend([tile_shape, tile_shape_int64, tile_node])
        return tile_node
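As a rough numpy equivalent (hypothetical values), the Concat/Cast/Tile chain above simply broadcasts the Fill value to the ONNX initial-state shape [num_directions, batch, hidden]:

import numpy as np

fill_val, batch_size, hidden_size = 0.0, 2, 3   # hypothetical values
# Concat([1], [batch], [hidden]) followed by Cast to int64 builds the Tile repeats
tile_repeats = np.array([1, batch_size, hidden_size], dtype=np.int64)
init_state = np.tile(np.array([[[fill_val]]], dtype=np.float32), tile_repeats)
assert init_state.shape == (1, batch_size, hidden_size)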
Example #3
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm,
                                             lstm_output_index, all_nodes,
                                             to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        # todo: figure out a better way to remove reverse op
        squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
        s_cnt = len(squeeze_nodes)
        if s_cnt > 1:
            raise ValueError(
                "unexpected number of squeeze following LSTM 1st output")
        elif s_cnt == 1:
            s = squeeze_nodes[0]
            trans_nodes = g.find_output_consumers(s.output[0])
            if len(trans_nodes) == 1:
                if trans_nodes[0].type == "Transpose":
                    reverse_nodes = g.find_output_consumers(
                        trans_nodes[0].output[0])
                elif is_reverse_op(trans_nodes[0]):
                    reverse_nodes = trans_nodes
                else:
                    raise ValueError("not found reverse op, unexpected")

                for r_op in reverse_nodes:
                    log.debug("remove reverse op called %s", r_op.name)
                    g.replace_all_inputs(all_nodes, r_op.output[0],
                                         r_op.input[0])
                    to_remove.append(r_op.name)
            else:
                raise ValueError(
                    "unexpected number of transpose after LSTM 1st output")
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM only should has 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        slice_node_fw = make_onnx_node(g, "Slice",
                                       [bi_lstm.output[lstm_output_index]],
                                       attr)
        all_nodes.append(slice_node_fw)
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index],
                             slice_node_fw.output[0])

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        slice_node_bw = make_onnx_node(g, "Slice",
                                       [bi_lstm.output[lstm_output_index]],
                                       attr)
        all_nodes.append(slice_node_bw)
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index],
                             slice_node_bw.output[0])
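For reference, the Slice attributes above split the bidirectional outputs back into per-direction tensors; a numpy sketch with hypothetical sizes (Y has shape [seq_length, num_directions, batch, hidden], Y_h/Y_c have [num_directions, batch, hidden]):

import numpy as np

seq_len, batch, hidden = 4, 2, 3                    # hypothetical sizes
bi_y = np.zeros((seq_len, 2, batch, hidden), dtype=np.float32)
fw_y = bi_y[:, 0:1, :, :]    # Slice(axes=[1], starts=[0], ends=[1])
bw_y = bi_y[:, 1:2, :, :]    # Slice(axes=[1], starts=[1], ends=[2])

bi_yh = np.zeros((2, batch, hidden), dtype=np.float32)
fw_yh = bi_yh[0:1]           # Slice(axes=[0], starts=[0], ends=[1])
bw_yh = bi_yh[1:2]           # Slice(axes=[0], starts=[1], ends=[2])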
    def _connect_lstm_output_to_graph(self, lstm_node, exit_node, rnn_props):
        exit_consumers = self.g.find_output_consumers(exit_node.output[0])
        gather_node = self._validate_output_exit_consumers(exit_consumers)
        if len(exit_consumers) != 2 or not gather_node:
            log.debug("lstm output exit node has %d consumers",
                      len(exit_consumers))
            raise ValueError("lstm output exit node check failed")

        # gather output for sure has shape [time, batch, hidden]
        gather_output_id = gather_node.output[0]
        log.debug("found output ta gather node %s", gather_output_id)
        # in tf batch major mode, output shape is : [batch, time, hidden]
        # in time major mode, output shape is: [time, batch, hidden]
        # in onnx, output shape is : [time, num_directions, batch, hidden]

        output_id = lstm_node.output[0]
        squeeze_node = make_onnx_node(self.g,
                                      "Squeeze", [output_id],
                                      attr={"axes": [1]})
        lstm_output_shape = self.g.get_shape(output_id)
        self.g.set_shape(
            squeeze_node.output[0],
            [lstm_output_shape[0], lstm_output_shape[2], lstm_output_shape[3]])
        self.g.set_dtype(squeeze_node.output[0], self.g.get_dtype(output_id))

        if not rnn_props.time_major:
            gather_consumers = self.g.find_output_consumers(gather_output_id)
            gather_trans_consumers = [
                n for n in gather_consumers if check_is_timemajor_transpose(n)
            ]
            if len(gather_trans_consumers) != 1:
                raise ValueError(
                    "batch major should expect a transpose after gather")
            trans = gather_trans_consumers[0]  # trans has rnn scope name

            # we only check the transpose here and do not reuse it, because it may
            # hold a non-const perm; instead we create a new transpose to replace it
            attr = {"perm": np.array([1, 0, 2], dtype=np.int64)}
            new_trans = make_onnx_node(self.g, "Transpose",
                                       [squeeze_node.output[0]], attr)
            trans_input_shape = self.g.get_shape(squeeze_node.output[0])
            self.g.replace_all_inputs(self.all_nodes, trans.output[0],
                                      new_trans.output[0])
            self.g.set_shape(new_trans.output[0], [
                trans_input_shape[1], trans_input_shape[0],
                trans_input_shape[2]
            ])
            self.g.set_dtype(new_trans.output[0],
                             self.g.get_dtype(squeeze_node.output[0]))
            self.all_nodes.extend([new_trans])

        self.g.replace_all_inputs(self.all_nodes, gather_output_id,
                                  squeeze_node.output[0])
        self.all_nodes.extend([squeeze_node])
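In numpy terms (hypothetical sizes), the Squeeze + Transpose pair maps the ONNX output layout back to the TF layouts described in the comments above:

import numpy as np

time, batch, hidden = 4, 2, 3                             # hypothetical sizes
y = np.zeros((time, 1, batch, hidden), dtype=np.float32)  # [time, num_directions, batch, hidden]
time_major = np.squeeze(y, axis=1)                        # [time, batch, hidden]
batch_major = np.transpose(time_major, (1, 0, 2))         # [batch, time, hidden]
assert batch_major.shape == (batch, time, hidden)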
Example #5
    def create_rnn_node(self, rnn_props):
        # specify if the RNN is forward, reverse, or bidirectional.
        # Must be one of forward (default), reverse, or bidirectional.
        # Here we won't mark bidirectional/reverse; another rewriter runs after this one
        # and combines a forward GRU and a backward GRU into a bidirectional one based on patterns.
        direction = "forward"
        num_direction = 1
        # todo: input_forget
        attr = {
            "direction": direction,
            "hidden_size": rnn_props.hidden_size,
            "activations": ["sigmoid", rnn_props.activation]
        }
        inputs = rnn_props.onnx_input_ids
        gru_inputs = [
            inputs["X"], inputs["W"], inputs["R"], inputs["B"],
            inputs["sequence_lens"], inputs["initial_state"]
        ]
        gru_node = make_onnx_node(self.g, "GRU", gru_inputs, attr, 2)

        x_shape = self.g.get_shape(gru_node.input[0])
        x_seq_length = x_shape[0]
        x_batch_size = x_shape[1]
        self.g.set_shape(
            gru_node.output[0],
            [x_seq_length, num_direction, x_batch_size, rnn_props.hidden_size])
        self.g.set_shape(gru_node.output[1],
                         [num_direction, x_batch_size, rnn_props.hidden_size])
        return gru_node
    def create_rnn_node(self, rnn_props):
        # specify if the RNN is forward, reverse, or bidirectional.
        # Must be one of forward (default), reverse, or bidirectional.
        # Here we won't mark bidirectional/reverse; another rewriter runs after this
        # one and combines a forward LSTM and a backward LSTM into a bidirectional
        # one based on patterns.
        direction = "forward"
        num_direction = 1
        # todo: input_forget
        attr = {"direction": direction, "hidden_size": rnn_props.hidden_size}
        inputs = rnn_props.onnx_input_ids
        lstm_inputs = [
            inputs["X"], inputs["W"], inputs["R"], inputs["B"],
            inputs["sequence_lens"], inputs["initial_h"], inputs["initial_c"]
        ]
        lstm_node = make_onnx_node(self.g, "LSTM", lstm_inputs, attr, 3)

        x_shape = self.g.get_shape(lstm_node.input[0])
        x_seq_length = x_shape[0]
        x_batch_size = x_shape[1]
        out_dtype = self.g.get_dtype(inputs["X"])
        self.g.set_shape(
            lstm_node.output[0],
            [x_seq_length, num_direction, x_batch_size, rnn_props.hidden_size])
        self.g.set_dtype(lstm_node.output[0], out_dtype)
        self.g.set_shape(lstm_node.output[1],
                         [num_direction, x_batch_size, rnn_props.hidden_size])
        self.g.set_dtype(lstm_node.output[1], out_dtype)
        self.g.copy_shape(lstm_node.output[1], lstm_node.output[2])
        self.g.set_dtype(lstm_node.output[2], out_dtype)
        return lstm_node
    def process_seq_length(self, rnn_props, seq_length_node):
        # output: [time step, batch size, input size]
        shape_node = make_onnx_node(self.g, "Shape", [rnn_props.x_input_id])

        # LSTMCell only allows inputs of [batch_size, input_size], so we assume dynamic_rnn's input has 3 dims.
        # Slice cannot support Int64 in OPSET 7, so we cast here.
        attr = {"to": onnx_pb.TensorProto.FLOAT}
        cast_shape_node = make_onnx_node(self.g, "Cast", [shape_node.output[0]], attr)
        self.g.copy_shape(shape_node.output[0], cast_shape_node.output[0])

        attr = {"axes": [0], "starts": [1], "ends": [2]}
        batchsize_node = make_onnx_node(self.g, "Slice", [cast_shape_node.output[0]], attr)

        # Tile's repeats must be INT64
        attr = {"to": onnx_pb.TensorProto.INT64}
        repeat_node = make_onnx_node(self.g, 'Cast', [batchsize_node.output[0]], attr)

        self.all_nodes.extend([shape_node, cast_shape_node, batchsize_node, repeat_node])

        if not seq_length_node:
            attr = {"axes": [0], "starts": [0], "ends": [1]}
            timestep_node = make_onnx_node(self.g, 'Slice', [cast_shape_node.output[0]], attr)

            tile_node = make_onnx_node(self.g, 'Tile', [timestep_node.output[0], repeat_node.output[0]])

            attr = {"to": onnx_pb.TensorProto.INT32}  # LSTM sequence_lens needs to be int32
            seq_length_node = make_onnx_node(self.g, 'Cast', [tile_node.output[0]], attr)

            self.all_nodes.extend([timestep_node, tile_node, seq_length_node])

        rnn_props.onnx_input_ids["sequence_lens"] = seq_length_node.output[0]
        return seq_length_node, batchsize_node
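A hedged numpy sketch (hypothetical input) of the Shape/Slice/Tile/Cast chain that fabricates sequence_lens when none is given, so every batch entry gets the full time-step count:

import numpy as np

x = np.zeros((5, 3, 8), dtype=np.float32)   # hypothetical [time, batch, input] input
time_steps, batch = x.shape[0], x.shape[1]
# Shape -> Slice(starts=[0], ends=[1]) yields [time_steps], Slice(starts=[1], ends=[2]) yields [batch];
# Tile repeats the time-step count batch times, and the final Cast makes it int32 for sequence_lens.
sequence_lens = np.tile(np.array([time_steps]), [batch]).astype(np.int32)
assert sequence_lens.tolist() == [5, 5, 5]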
Example #8
    def _connect_gru_state_to_graph(self, gru_node, exit_node, rnn_props):
        # in tf, state output shape is: [batch, hidden]
        # in onnx, output shape is: [number_directions, batch, hidden]
        output_id = gru_node.output[1]
        squeeze_node = make_onnx_node(self.g, "Squeeze", [output_id], attr={"axes": [0]})
        gru_state_shape = self.g.get_shape(output_id)
        self.g.set_shape(squeeze_node.output[0], [gru_state_shape[1], gru_state_shape[2]])
        self.all_nodes.extend([squeeze_node])
        self.g.replace_all_inputs(self.all_nodes, exit_node.output[0], squeeze_node.output[0])
    def _create_scan_node(self, context, scan_props):
        log.debug("create scan node")
        # we do not pass sequence_length here because the current batch size is 1,
        # not the original batch size; the original seq_length is used by the
        # loop body of the Scan op.
        scan_node = make_onnx_node(
            self.g,
            "Scan", [""] + scan_props.initial_state_and_scan_inputs,
            attr={"num_scan_inputs": len(scan_props.loop_scan_inputs)},
            output_count=len(scan_props.loop_state_outputs +
                             scan_props.loop_scan_outputs),
            skip_conversion=True)

        # the first state var is time-iterator.
        index = 0
        time_input_shape = self.g.get_shape(scan_node.input[1])
        time_input_dtype = self.g.get_dtype(scan_node.input[1])

        log.debug(
            "_create_scan_node - set scan state_output shape for %s[%s]:%s",
            scan_node.name, index, time_input_shape)
        self.g.set_shape(scan_node.output[index], time_input_shape)
        self.g.set_dtype(scan_node.output[index], time_input_dtype)
        index += 1

        # for other state vars
        state_input_shape = self.g.get_shape(scan_node.input[2])
        state_input_dtype = self.g.get_dtype(scan_node.input[2])
        for i in range(len(scan_props.loop_state_outputs) - 1):
            log.debug(
                "_create_scan_node - set scan state_output shape for %s[%s]:%s",
                scan_node.name, index, state_input_shape)
            self.g.set_shape(scan_node.output[index], state_input_shape)
            self.g.set_dtype(scan_node.output[index], state_input_dtype)
            index += 1

        last_scan_input_shape = self.g.get_shape(scan_node.input[-1])
        batch = last_scan_input_shape[0]  # should be 1
        time = last_scan_input_shape[1]
        for i in range(len(scan_props.loop_scan_outputs)):
            scan_out_dtype = self.g.get_dtype(scan_props.loop_scan_outputs[i])
            output_shape = self.g.get_shape(scan_props.loop_scan_outputs[i])
            scan_output_shape = [batch, time] + output_shape
            log.debug(
                "scan output [%s] has shape %s, batch:%s, time: %s, cell output shape: %s",
                scan_props.loop_scan_outputs[i], scan_output_shape, batch,
                time, output_shape)
            log.debug(
                "_create_scan_node - set scan scan_output shape for %s[%s]:%s",
                scan_node.name, index, scan_output_shape)
            self.g.set_shape(scan_node.output[index], scan_output_shape)
            self.g.set_dtype(scan_node.output[index], scan_out_dtype)
            index += 1

        return scan_node
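The scan outputs keep the fake batch dimension in front, so each scan output shape is just [batch, time] prepended to the per-iteration cell output shape; with hypothetical numbers:

batch, time = 1, 7                  # batch is the fake batch size, expected to be 1
cell_output_shape = [3, 16]         # hypothetical per-iteration output shape
scan_output_shape = [batch, time] + cell_output_shape
assert scan_output_shape == [1, 7, 3, 16]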
Example #10
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm,
                                             lstm_output_index, all_nodes,
                                             to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, lstm_bw)
        if not reverse_nodes:
            raise ValueError(
                "should not happen y_output is not followed with reverse node")

        for r_op in reverse_nodes:
            log.debug("remove reverse op called %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM only should has 3 outputs.")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        slice_node_fw = make_onnx_node(g, "Slice",
                                       [bi_lstm.output[lstm_output_index]],
                                       attr)
        all_nodes.append(slice_node_fw)
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index],
                             slice_node_fw.output[0])

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        slice_node_bw = make_onnx_node(g, "Slice",
                                       [bi_lstm.output[lstm_output_index]],
                                       attr)
        all_nodes.append(slice_node_bw)
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index],
                             slice_node_bw.output[0])
    def _convert_timemajor_transpose(self, node):
        if not check_is_timemajor_transpose(node):
            log.debug("not found timemajor transpose")
            return None

        log.debug("found timemajor transpose")

        attr = {"perm": np.array([1, 0, 2], dtype=np.int64)}
        new_trans = make_onnx_node(self.g, "Transpose", [node.input[0]], attr)

        self.g.copy_shape(node.output[0], new_trans.output[0])
        self.g.replace_all_inputs(self.g.get_nodes(), node.output[0], new_trans.output[0])
        return new_trans
Example #12
    def _connect_lstm_ych_to_graph(self, lstm_node, exit_node, rnn_props):
        # in tf, concat of y_c and y_h output shape is: [batch, hidden *2]
        # in onnx, y_c/y_h output shape is: [number_directions, batch, hidden]

        concat = make_onnx_node(self.g,
                                "Concat",
                                [lstm_node.output[2], lstm_node.output[1]],
                                attr={"axis": 2})
        yc_shape = self.g.get_shape(lstm_node.output[2])
        self.g.set_shape(concat.output[0],
                         [yc_shape[0], yc_shape[1], yc_shape[2] * 2])

        squeeze_node = make_onnx_node(self.g,
                                      "Squeeze", [concat.output[0]],
                                      attr={"axes": [0]})
        concat_shape = self.g.get_shape(concat.output[0])
        self.g.set_shape(squeeze_node.output[0],
                         [concat_shape[1], concat_shape[2]])
        self.all_nodes.extend([concat, squeeze_node])

        self.g.replace_all_inputs(self.all_nodes, exit_node.output[0],
                                  squeeze_node.output[0])
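In numpy terms (hypothetical sizes), the Concat + Squeeze above rebuilds TF's [batch, hidden * 2] c/h concat from the ONNX Y_c/Y_h outputs:

import numpy as np

batch, hidden = 2, 3                                          # hypothetical sizes
y_c = np.zeros((1, batch, hidden), dtype=np.float32)          # ONNX Y_c: [num_directions, batch, hidden]
y_h = np.ones((1, batch, hidden), dtype=np.float32)           # ONNX Y_h
ch = np.squeeze(np.concatenate([y_c, y_h], axis=2), axis=0)   # -> [batch, hidden * 2], c first
assert ch.shape == (batch, 2 * hidden)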
    def _create_squeeze_node(self, target_name, input_id):
        squeeze_node = make_onnx_node(self.g,
                                      "Squeeze", [input_id],
                                      attr={"axes": [0]},
                                      skip_conversion=True,
                                      op_name_scope=target_name)
        input_shape = self.g.get_shape(input_id)
        if input_shape is None:
            raise ValueError("shape of " + input_id + " is None")
        input_shape = list(input_shape)[1:]
        self.g.set_shape(squeeze_node.output[0], input_shape)
        self.g.set_dtype(squeeze_node.output[0], self.g.get_dtype(input_id))

        return squeeze_node
Example #14
def _process_single_init_node(g, fw_init_input_id, bw_init_input_id,
                              to_append):
    fw_init_is_const, init_fw_val = check_const(g, fw_init_input_id)
    bw_init_is_const, init_bw_val = check_const(g, bw_init_input_id)
    if fw_init_is_const and bw_init_is_const:
        initial_val = np.concatenate((init_fw_val, init_bw_val), axis=0)
        init_name = utils.make_name("initial")
        init_node = g.make_const(init_name, initial_val, skip_conversion=True)
    else:
        attr = {"axis": 0}
        init_node = make_onnx_node(g, "Concat",
                                   [fw_init_input_id, bw_init_input_id], attr)
        to_append.append(init_node)

    return init_node
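When both per-direction initial states are constants they can be merged offline instead of with a runtime Concat node; the constant path amounts to (hypothetical shapes):

import numpy as np

batch, hidden = 2, 3                                          # hypothetical sizes
init_fw_val = np.zeros((1, batch, hidden), dtype=np.float32)  # forward initial state
init_bw_val = np.zeros((1, batch, hidden), dtype=np.float32)  # backward initial state
initial_val = np.concatenate((init_fw_val, init_bw_val), axis=0)
assert initial_val.shape == (2, batch, hidden)                # num_directions becomes 2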
Example #15
    def _process_init_nodes(self, initializer_input_id, rnn_props):
        # copied from lstm_rewriter
        # todo: remove this once the Fill op is supported
        fill_ch_init_node = self._workaround_fill_ch_init_node(initializer_input_id, rnn_props)
        if fill_ch_init_node:
            return fill_ch_init_node.output[0]

        node = self.g.get_node_by_name(initializer_input_id)
        if node.is_const():
            val = node.get_tensor_value()
            initial_name = utils.make_name("Const")
            new_val = np.expand_dims(val, axis=0)
            const_node = self.g.make_const(initial_name, new_val)
            return const_node.output[0]
        squeeze_node = make_onnx_node(self.g, "Unsqueeze", [initializer_input_id], attr={"axes": [0]})
        self.g.replace_all_inputs(self.g.get_nodes(), initializer_input_id, squeeze_node.output[0])
        self.all_nodes.append(squeeze_node)
        return squeeze_node.output[0]
Example #16
def process_bilstm(g, bi_lstms):
    for fw, bw in bi_lstms:
        input_id = fw[0]
        log.debug("=========================")
        log.debug("start handling potential bidirectional lstm %s", input_id)

        lstm_fw = fw[1]
        lstm_bw = bw[1]

        w_fw = get_np_val_for_const(g, lstm_fw, 1)
        w_bw = get_np_val_for_const(g, lstm_bw, 1)
        r_fw = get_np_val_for_const(g, lstm_fw, 2)
        r_bw = get_np_val_for_const(g, lstm_bw, 2)
        b_fw = get_np_val_for_const(g, lstm_fw, 3)
        b_bw = get_np_val_for_const(g, lstm_bw, 3)
        W = np.concatenate((w_fw, w_bw), axis=0)
        R = np.concatenate((r_fw, r_bw), axis=0)
        B = np.concatenate((b_fw, b_bw), axis=0)

        all_nodes = g.get_nodes()
        if len(lstm_fw.inputs) == len(lstm_bw.inputs):
            if len(lstm_fw.inputs) > 4:
                h_node, c_node = process_ch_init_nodes(g, lstm_fw, lstm_bw,
                                                       all_nodes)
        else:
            log.error("fw, bw lstm inputs num is not consistent. stop")
            continue

        # create node
        w_name = utils.make_name("W")
        w_node = g.make_const(w_name, W, skip_conversion=True)

        r_name = utils.make_name("R")
        r_node = g.make_const(r_name, R, skip_conversion=True)

        b_name = utils.make_name("B")
        b_node = g.make_const(b_name, B, skip_conversion=True)
        lstm_inputs = [
            lstm_fw.input[0], w_node.output[0], r_node.output[0],
            b_node.output[0]
        ]
        if len(lstm_fw.inputs) > 4:
            lstm_inputs.extend(
                [lstm_fw.input[4], h_node.output[0], c_node.output[0]])

        direction = "bidirectional"
        if lstm_fw.get_attr("hidden_size").i == lstm_bw.get_attr(
                "hidden_size").i:
            hidden_size = lstm_fw.get_attr("hidden_size").i
        else:
            log.error("fw and bw has different hidden_size, skip")
            continue

        attr = {"direction": direction, "hidden_size": hidden_size}
        bi_lstm_node = make_onnx_node(g,
                                      "LSTM",
                                      lstm_inputs,
                                      attr=attr,
                                      output_count=3)
        all_nodes.append(bi_lstm_node)
        log.debug("processing output nodes")

        to_remove = [
            lstm_fw.name, lstm_fw.input[1], lstm_fw.input[2], lstm_fw.input[3],
            lstm_bw.name, lstm_bw.input[1], lstm_bw.input[2], lstm_bw.input[3]
        ]
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw,
                                                 bi_lstm_node, 0, all_nodes,
                                                 to_remove)
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw,
                                                 bi_lstm_node, 1, all_nodes,
                                                 to_remove)
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw,
                                                 bi_lstm_node, 2, all_nodes,
                                                 to_remove)

        lstm_bw_old_x = lstm_bw.input[0]
        new_nodes = []
        for n in all_nodes:
            if n.name not in to_remove:
                new_nodes.append(n)
        g.set_nodes(new_nodes)

        old_x_consumers = g.find_output_consumers(lstm_bw_old_x)
        # the transpose/reverse here must be followed by LSTM if it is still useful.
        # this is guaranteed by dynamic_rnn logic.
        old_x_has_lstm_as_consumer = [
            n for n in old_x_consumers if n.type == "LSTM"
        ]
        if not old_x_has_lstm_as_consumer:
            log.debug("plan to remove useless reverse op in bw")
            reverse_node = g.get_node_by_name(lstm_bw_old_x)

            if reverse_node.type == "Transpose":
                reverse_node = reverse_node.inputs[0]

            g.replace_all_inputs(g.get_nodes(), reverse_node.output[0],
                                 reverse_node.input[0])
            new_nodes = g.get_nodes()
            new_nodes.remove(reverse_node)
            g.set_nodes(new_nodes)
        else:
            raise ValueError(
                "Reverse is still used by LSTM as input, cannot remove")

    g.update_proto()
    return g.get_nodes()
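The W/R/B merge at the top of the loop stacks the per-direction parameters along axis 0, which gives the num_directions = 2 layout a bidirectional ONNX LSTM expects; a numpy sketch for W with hypothetical sizes:

import numpy as np

hidden, input_size = 4, 5                                        # hypothetical sizes
w_fw = np.zeros((1, 4 * hidden, input_size), dtype=np.float32)   # forward LSTM W
w_bw = np.zeros((1, 4 * hidden, input_size), dtype=np.float32)   # backward LSTM W
W = np.concatenate((w_fw, w_bw), axis=0)
assert W.shape == (2, 4 * hidden, input_size)                    # [num_directions, 4*hidden, input_size]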
    def _adapt_scan_sequence_input_or_output(self,
                                             target_name,
                                             input_id,
                                             handle_output=False):
        nodes_to_add = []
        shape_node = make_onnx_node(self.g, "Shape", [input_id])
        nodes_to_add.append(shape_node)
        inferred_shape = self.g.get_shape(input_id)
        if handle_output is True:
            # handle output:
            # if required dim values don't contain more than one -1,
            # just use a const for Reshape's shape input.
            if inferred_shape is not None and inferred_shape[1:].count(
                    -1) <= 1:
                new_shape_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array(inferred_shape[1:], dtype=np.int64))
            else:
                # otherwise, get the dims dynamically, e.g. remove the fake batch size (e.g. 1)
                # from [1, time, real-batch, ...]
                origin_shape_node = make_onnx_node(
                    self.g, "Cast", [shape_node.output[0]],
                    {"to": onnx_pb.TensorProto.FLOAT})
                nodes_to_add.append(origin_shape_node)

                sliced_shape_node = make_onnx_node(
                    self.g, "Slice", [origin_shape_node.output[0]], {
                        "axes": [0],
                        "starts": [1],
                        "ends": [sys.maxsize]
                    })
                nodes_to_add.append(sliced_shape_node)

                new_shape_node = make_onnx_node(
                    self.g, "Cast", [sliced_shape_node.output[0]],
                    {"to": onnx_pb.TensorProto.INT64})
                nodes_to_add.append(new_shape_node)

            new_shape = inferred_shape[1:]
        else:
            # handle input:
            if inferred_shape is not None and inferred_shape.count(-1) <= 1:
                new_shape_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array([1] + inferred_shape, dtype=np.int64))
            else:
                # add a fake batch size : 1
                fake_batch_size_node = self.g.make_const(
                    utils.make_name(target_name + "_target_shape"),
                    np.array([
                        1,
                    ], dtype=np.int64))
                new_shape_node = make_onnx_node(
                    self.g, "Concat",
                    [fake_batch_size_node.output[0], shape_node.output[0]],
                    {"axis": 0})
                nodes_to_add.append(new_shape_node)
            new_shape = [1] + inferred_shape

        reshape_node = make_onnx_node(self.g,
                                      "Reshape",
                                      [input_id, new_shape_node.output[0]],
                                      skip_conversion=True,
                                      op_name_scope=target_name)
        nodes_to_add.append(reshape_node)
        self.g.set_shape(reshape_node.output[0], new_shape)
        self.g.set_dtype(reshape_node.output[0], self.g.get_dtype(input_id))
        log.debug("create Reshape for scan output %s, with output shape %s",
                  reshape_node.output[0], new_shape)
        return nodes_to_add
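With fully known shapes, the Reshape targets built above either prepend a fake batch of 1 (input side) or drop it again (output side); a numpy sketch with a hypothetical input:

import numpy as np

x = np.zeros((7, 2, 8), dtype=np.float32)                     # hypothetical scan input, e.g. [time, batch, input]
as_scan_input = x.reshape([1] + list(x.shape))                # input side: prepend the fake batch -> [1, 7, 2, 8]
back_to_graph = as_scan_input.reshape(list(as_scan_input.shape[1:]))  # output side: drop the fake batch
assert back_to_graph.shape == x.shape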
Example #18
def process_bigru(g, bi_grus):
    for fw, bw in bi_grus:
        input_id = fw[0]
        log.debug("=========================")
        log.debug("start handling potential bidirectional gru %s", input_id)

        gru_fw = fw[1]
        gru_bw = bw[1]

        w_fw = get_np_val_for_const(g, gru_fw, 1)
        w_bw = get_np_val_for_const(g, gru_bw, 1)
        r_fw = get_np_val_for_const(g, gru_fw, 2)
        r_bw = get_np_val_for_const(g, gru_bw, 2)
        b_fw = get_np_val_for_const(g, gru_fw, 3)
        b_bw = get_np_val_for_const(g, gru_bw, 3)
        W = np.concatenate((w_fw, w_bw), axis=0)
        R = np.concatenate((r_fw, r_bw), axis=0)
        B = np.concatenate((b_fw, b_bw), axis=0)

        all_nodes = g.get_nodes()
        if len(gru_fw.inputs) == len(gru_bw.inputs):
            if len(gru_fw.inputs) > 4:
                initializer_node = process_init_nodes(g, gru_fw, gru_bw,
                                                      all_nodes)
        else:
            log.error("fw, bw gru inputs num is not consistent. stop")
            continue

        # create node
        w_name = utils.make_name("W")
        w_node = g.make_const(w_name, W, skip_conversion=True)

        r_name = utils.make_name("R")
        r_node = g.make_const(r_name, R, skip_conversion=True)

        b_name = utils.make_name("B")
        b_node = g.make_const(b_name, B, skip_conversion=True)
        gru_inputs = [
            gru_fw.input[0], w_node.output[0], r_node.output[0],
            b_node.output[0]
        ]
        if len(gru_fw.inputs) > 4:
            gru_inputs.extend([gru_fw.input[4], initializer_node.output[0]])

        direction = "bidirectional"
        if gru_fw.get_attr("hidden_size").i == gru_bw.get_attr(
                "hidden_size").i:
            hidden_size = gru_fw.get_attr("hidden_size").i
        else:
            log.error("fw and bw has different hidden_size, skip")
            continue
        # activations have to be taken care of:
        # the attr here is a proto, while make_onnx_node needs a dict
        activations = [
            act.decode("utf-8")
            for act in gru_fw.get_attr("activations").strings
        ]
        activations += [
            act.decode("utf-8")
            for act in gru_bw.get_attr("activations").strings
        ]
        attr = {
            "direction": direction,
            "hidden_size": hidden_size,
            "activations": activations
        }
        bi_gru_node = make_onnx_node(g,
                                     "GRU",
                                     gru_inputs,
                                     attr=attr,
                                     output_count=2)
        all_nodes.append(bi_gru_node)
        log.debug("processing output nodes")

        to_remove = [
            gru_fw.name, gru_fw.input[1], gru_fw.input[2], gru_fw.input[3],
            gru_bw.name, gru_bw.input[1], gru_bw.input[2], gru_bw.input[3]
        ]
        slice_bilstm_for_original_lstm_consumers(g, gru_fw, gru_bw,
                                                 bi_gru_node, 0, all_nodes,
                                                 to_remove)
        slice_bilstm_for_original_lstm_consumers(g, gru_fw, gru_bw,
                                                 bi_gru_node, 1, all_nodes,
                                                 to_remove)

        gru_bw_old_x = gru_bw.input[0]
        new_nodes = []
        for n in all_nodes:
            if n.name not in to_remove:
                new_nodes.append(n)
        g.set_nodes(new_nodes)

        old_x_consumers = g.find_output_consumers(gru_bw_old_x)
        # the transpose/reverse here must be followed by GRU if it is still useful.
        # this is guaranteed by dynamic_rnn logic.
        old_x_has_gru_as_consumer = [
            n for n in old_x_consumers if n.type == "GRU"
        ]
        if not old_x_has_gru_as_consumer:
            log.debug("plan to remove useless reverse op in bw")
            reverse_node = g.get_node_by_name(gru_bw_old_x)

            if reverse_node.type == "Transpose":
                reverse_node = reverse_node.inputs[0]

            g.replace_all_inputs(g.get_nodes(), reverse_node.output[0],
                                 reverse_node.input[0])
            new_nodes = g.get_nodes()
            new_nodes.remove(reverse_node)
            g.set_nodes(new_nodes)
        else:
            raise ValueError(
                "Reverse is still used by GRU as input, cannot remove")

    g.update_proto()
    return g.get_nodes()