def _process_non_tuple_ch_init_nodes(self, rnn_props):
    input_id = rnn_props.var_initializers["ct_ht"]
    hidden_size = rnn_props.hidden_size

    # todo: remove this once the Fill op is supported
    fill_ch_init_node = self._workaround_fill_ch_init_node(input_id, rnn_props)
    if fill_ch_init_node:
        return fill_ch_init_node.output[0], fill_ch_init_node.output[0]

    attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
    slice_node1 = make_onnx_node(self.g, "Slice", [input_id], attr)
    unsqueeze_node_1 = make_onnx_node(self.g, "Unsqueeze", [slice_node1.output[0]], attr={"axes": [0]})

    attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size * 2]}
    slice_node2 = make_onnx_node(self.g, "Slice", [input_id], attr)
    unsqueeze_node_2 = make_onnx_node(self.g, "Unsqueeze", [slice_node2.output[0]], attr={"axes": [0]})

    self.all_nodes.extend([slice_node1, slice_node2, unsqueeze_node_1, unsqueeze_node_2])
    return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
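# Not part of the rewriter: a minimal numpy sketch of what the Slice/Unsqueeze pairs above
# compute on a fused ct_ht initializer. The sizes below are made-up illustration values.
import numpy as np

batch_size, hidden_size = 2, 3
ct_ht = np.random.rand(batch_size, 2 * hidden_size).astype(np.float32)

# Slice(axes=[1], starts=[0], ends=[hidden_size]) followed by Unsqueeze(axes=[0])
first_half = np.expand_dims(ct_ht[:, :hidden_size], axis=0)                    # [1, batch, hidden]
# Slice(axes=[1], starts=[hidden_size], ends=[2 * hidden_size]) followed by Unsqueeze(axes=[0])
second_half = np.expand_dims(ct_ht[:, hidden_size:2 * hidden_size], axis=0)    # [1, batch, hidden]

assert first_half.shape == second_half.shape == (1, batch_size, hidden_size)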
def _workaround_fill_ch_init_node(self, initializer_input_id, rnn_props):
    node = self.g.get_node_by_name(initializer_input_id)
    if node.type != "Fill":
        return None

    fill_val = node.inputs[1].get_tensor_value()[0]
    fill_val_dtype = utils.ONNX_TO_NUMPY_DTYPE[node.inputs[1].dtype]

    # the constants below are float32 so that all of Concat's inputs share one dtype
    # (Concat in OPSET 7 does not support int64); we cast the result to int64 afterwards.
    num_direction_node = self.g.make_const(utils.make_name("Const"),
                                           np.array([1], dtype=np.float32))
    h_node = self.g.make_const(utils.make_name("Const"),
                               np.array([rnn_props.hidden_size], dtype=np.float32))
    b_node = rnn_props.batch_size_node
    tile_shape = make_onnx_node(
        self.g, "Concat",
        [num_direction_node.output[0], b_node.output[0], h_node.output[0]],
        attr={"axis": 0})

    # Tile's repeats must be INT64
    attr = {"to": onnx_pb.TensorProto.INT64}
    tile_shape_int64 = make_onnx_node(self.g, "Cast", [tile_shape.output[0]], attr)

    const_node = self.g.make_const(utils.make_name("Const"),
                                   np.array([[[fill_val]]], dtype=fill_val_dtype))
    tile_node = make_onnx_node(self.g, "Tile",
                               [const_node.output[0], tile_shape_int64.output[0]])
    self.all_nodes.extend([tile_shape, tile_shape_int64, tile_node])
    return tile_node
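# Not part of the rewriter: a numpy sketch of what the Const + Concat + Cast + Tile chain above
# produces. fill_val, batch_size and hidden_size below are hypothetical stand-ins.
import numpy as np

fill_val, batch_size, hidden_size = 0.0, 2, 3

const = np.array([[[fill_val]]], dtype=np.float32)                  # shape [1, 1, 1]
tile_shape = np.array([1, batch_size, hidden_size], dtype=np.int64)  # [num_directions, batch, hidden]
init_state = np.tile(const, tile_shape)

assert init_state.shape == (1, batch_size, hidden_size)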
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm, lstm_output_index, all_nodes, to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        # todo: figure out a better way to remove the reverse op
        squeeze_nodes = [c for c in bw_consumers if c.type == "Squeeze"]
        s_cnt = len(squeeze_nodes)
        if s_cnt > 1:
            raise ValueError("unexpected number of Squeeze nodes following LSTM's 1st output")
        elif s_cnt == 1:
            s = squeeze_nodes[0]
            trans_nodes = g.find_output_consumers(s.output[0])
            if len(trans_nodes) == 1:
                if trans_nodes[0].type == "Transpose":
                    reverse_nodes = g.find_output_consumers(trans_nodes[0].output[0])
                elif is_reverse_op(trans_nodes[0]):
                    reverse_nodes = trans_nodes
                else:
                    raise ValueError("reverse op not found, unexpected")

                for r_op in reverse_nodes:
                    log.debug("remove reverse op %s", r_op.name)
                    g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
                    to_remove.append(r_op.name)
            else:
                raise ValueError("unexpected number of Transpose nodes after LSTM's 1st output")
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM should only have 3 outputs")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        slice_node_fw = make_onnx_node(g, "Slice", [bi_lstm.output[lstm_output_index]], attr)
        all_nodes.append(slice_node_fw)
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index], slice_node_fw.output[0])

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        slice_node_bw = make_onnx_node(g, "Slice", [bi_lstm.output[lstm_output_index]], attr)
        all_nodes.append(slice_node_bw)
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index], slice_node_bw.output[0])
def _connect_lstm_output_to_graph(self, lstm_node, exit_node, rnn_props):
    exit_consumers = self.g.find_output_consumers(exit_node.output[0])
    gather_node = self._validate_output_exit_consumers(exit_consumers)
    if len(exit_consumers) != 2 or not gather_node:
        log.debug("lstm output exit node has %d consumers", len(exit_consumers))
        raise ValueError("lstm output exit node check failed")

    # gather output always has shape [time, batch, hidden]
    gather_output_id = gather_node.output[0]
    log.debug("found output ta gather node %s", gather_output_id)

    # in tf batch-major mode, output shape is [batch, time, hidden]
    # in time-major mode, output shape is [time, batch, hidden]
    # in onnx, output shape is [time, num_directions, batch, hidden]
    output_id = lstm_node.output[0]
    squeeze_node = make_onnx_node(self.g, "Squeeze", [output_id], attr={"axes": [1]})
    lstm_output_shape = self.g.get_shape(output_id)
    self.g.set_shape(squeeze_node.output[0],
                     [lstm_output_shape[0], lstm_output_shape[2], lstm_output_shape[3]])
    self.g.set_dtype(squeeze_node.output[0], self.g.get_dtype(output_id))

    if not rnn_props.time_major:
        gather_consumers = self.g.find_output_consumers(gather_output_id)
        gather_trans_consumers = [n for n in gather_consumers if check_is_timemajor_transpose(n)]
        if len(gather_trans_consumers) != 1:
            raise ValueError("batch major should expect a transpose after gather")
        trans = gather_trans_consumers[0]  # trans has rnn scope name

        # we only check the transpose here and do not reuse it, because it may hold
        # non-const perms; so we create a new transpose to replace it.
        attr = {"perm": np.array([1, 0, 2], dtype=np.int64)}
        new_trans = make_onnx_node(self.g, "Transpose", [squeeze_node.output[0]], attr)

        trans_input_shape = self.g.get_shape(squeeze_node.output[0])
        self.g.replace_all_inputs(self.all_nodes, trans.output[0], new_trans.output[0])
        self.g.set_shape(new_trans.output[0],
                         [trans_input_shape[1], trans_input_shape[0], trans_input_shape[2]])
        self.g.set_dtype(new_trans.output[0], self.g.get_dtype(squeeze_node.output[0]))
        self.all_nodes.extend([new_trans])

    self.g.replace_all_inputs(self.all_nodes, gather_output_id, squeeze_node.output[0])
    self.all_nodes.extend([squeeze_node])
def create_rnn_node(self, rnn_props):
    # specify if the RNN is forward, reverse, or bidirectional.
    # Must be one of forward (default), reverse, or bidirectional.
    # Here we won't mark bidirectional/reverse; another rewriter running after this one
    # will combine a forward GRU and a backward GRU into a bidirectional one based on patterns.
    direction = "forward"
    num_direction = 1
    # todo: input_forget
    attr = {"direction": direction, "hidden_size": rnn_props.hidden_size,
            "activations": ["sigmoid", rnn_props.activation]}

    inputs = rnn_props.onnx_input_ids
    gru_inputs = [
        inputs["X"], inputs["W"], inputs["R"], inputs["B"],
        inputs["sequence_lens"], inputs["initial_state"]]
    gru_node = make_onnx_node(self.g, "GRU", gru_inputs, attr, 2)

    x_shape = self.g.get_shape(gru_node.input[0])
    x_seq_length = x_shape[0]
    x_batch_size = x_shape[1]
    self.g.set_shape(gru_node.output[0],
                     [x_seq_length, num_direction, x_batch_size, rnn_props.hidden_size])
    self.g.set_shape(gru_node.output[1],
                     [num_direction, x_batch_size, rnn_props.hidden_size])
    return gru_node
def create_rnn_node(self, rnn_props):
    # specify if the RNN is forward, reverse, or bidirectional.
    # Must be one of forward (default), reverse, or bidirectional.
    # Here we won't mark bidirectional/reverse; another rewriter running after this one
    # will combine a forward LSTM and a backward LSTM into a bidirectional one based on patterns.
    direction = "forward"
    num_direction = 1
    # todo: input_forget
    attr = {"direction": direction, "hidden_size": rnn_props.hidden_size}

    inputs = rnn_props.onnx_input_ids
    lstm_inputs = [
        inputs["X"], inputs["W"], inputs["R"], inputs["B"],
        inputs["sequence_lens"], inputs["initial_h"], inputs["initial_c"]]
    lstm_node = make_onnx_node(self.g, "LSTM", lstm_inputs, attr, 3)

    x_shape = self.g.get_shape(lstm_node.input[0])
    x_seq_length = x_shape[0]
    x_batch_size = x_shape[1]
    out_dtype = self.g.get_dtype(inputs["X"])

    self.g.set_shape(lstm_node.output[0],
                     [x_seq_length, num_direction, x_batch_size, rnn_props.hidden_size])
    self.g.set_dtype(lstm_node.output[0], out_dtype)
    self.g.set_shape(lstm_node.output[1],
                     [num_direction, x_batch_size, rnn_props.hidden_size])
    self.g.set_dtype(lstm_node.output[1], out_dtype)
    self.g.copy_shape(lstm_node.output[1], lstm_node.output[2])
    self.g.set_dtype(lstm_node.output[2], out_dtype)
    return lstm_node
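# Not part of the rewriter: a small sketch, using the public onnx.helper API, of the standalone
# LSTM node this function effectively builds through make_onnx_node. Tensor names and the
# hidden size are illustrative assumptions, not values taken from the graph.
from onnx import helper

hidden_size = 128
lstm_node_example = helper.make_node(
    "LSTM",
    inputs=["X", "W", "R", "B", "sequence_lens", "initial_h", "initial_c"],
    outputs=["Y", "Y_h", "Y_c"],
    direction="forward",
    hidden_size=hidden_size,
)
# Per the ONNX spec, Y has shape [seq_length, num_directions, batch, hidden_size],
# while Y_h and Y_c have shape [num_directions, batch, hidden_size] - which is exactly
# what the set_shape calls above record for the node's three outputs.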
def process_seq_length(self, rnn_props, seq_length_node):
    # output: [time step, batch size, input size]
    shape_node = make_onnx_node(self.g, "Shape", [rnn_props.x_input_id])

    # LSTMCell only allows inputs of [batch size, input size], so we assume dynamic_rnn has 3 dims.
    # Slice does not support int64 in OPSET 7, so we cast here.
    attr = {"to": onnx_pb.TensorProto.FLOAT}
    cast_shape_node = make_onnx_node(self.g, "Cast", [shape_node.output[0]], attr)
    self.g.copy_shape(shape_node.output[0], cast_shape_node.output[0])

    attr = {"axes": [0], "starts": [1], "ends": [2]}
    batchsize_node = make_onnx_node(self.g, "Slice", [cast_shape_node.output[0]], attr)

    # Tile's repeats must be INT64
    attr = {"to": onnx_pb.TensorProto.INT64}
    repeat_node = make_onnx_node(self.g, "Cast", [batchsize_node.output[0]], attr)
    self.all_nodes.extend([shape_node, cast_shape_node, batchsize_node, repeat_node])

    if not seq_length_node:
        attr = {"axes": [0], "starts": [0], "ends": [1]}
        timestep_node = make_onnx_node(self.g, "Slice", [cast_shape_node.output[0]], attr)

        tile_node = make_onnx_node(self.g, "Tile", [timestep_node.output[0], repeat_node.output[0]])

        # LSTM sequence_lens needs to be int32
        attr = {"to": onnx_pb.TensorProto.INT32}
        seq_length_node = make_onnx_node(self.g, "Cast", [tile_node.output[0]], attr)
        self.all_nodes.extend([timestep_node, tile_node, seq_length_node])

    rnn_props.onnx_input_ids["sequence_lens"] = seq_length_node.output[0]
    return seq_length_node, batchsize_node
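# Not part of the rewriter: a numpy sketch of the Shape/Slice/Tile/Cast chain above, i.e. how
# sequence_lens is derived when the model does not provide one. The input sizes are made up.
import numpy as np

x = np.zeros((5, 3, 7), dtype=np.float32)       # time-major input: [time, batch, input_size]
shape = np.array(x.shape, dtype=np.float32)      # Shape -> Cast(float), as above

batch_size = shape[1:2]                          # Slice(axes=[0], starts=[1], ends=[2])
time_step = shape[0:1]                           # Slice(axes=[0], starts=[0], ends=[1])

# Tile the time step once per batch entry, then cast to int32 for LSTM's sequence_lens.
seq_lens = np.tile(time_step, batch_size.astype(np.int64)).astype(np.int32)
assert seq_lens.tolist() == [5, 5, 5]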
def _connect_gru_state_to_graph(self, gru_node, exit_node, rnn_props):
    # in tf, state output shape is [batch, hidden]
    # in onnx, output shape is [num_directions, batch, hidden]
    output_id = gru_node.output[1]
    squeeze_node = make_onnx_node(self.g, "Squeeze", [output_id], attr={"axes": [0]})
    gru_state_shape = self.g.get_shape(output_id)
    self.g.set_shape(squeeze_node.output[0], [gru_state_shape[1], gru_state_shape[2]])
    self.all_nodes.extend([squeeze_node])
    self.g.replace_all_inputs(self.all_nodes, exit_node.output[0], squeeze_node.output[0])
def _create_scan_node(self, context, scan_props):
    log.debug("create scan node")
    # sequence_length is not provided here because the current batch size is 1, not the
    # original batch size; the original seq_length is used by the loop body of the Scan op.
    scan_node = make_onnx_node(
        self.g, "Scan", [""] + scan_props.initial_state_and_scan_inputs,
        attr={"num_scan_inputs": len(scan_props.loop_scan_inputs)},
        output_count=len(scan_props.loop_state_outputs + scan_props.loop_scan_outputs),
        skip_conversion=True)

    # the first state var is the time iterator.
    index = 0
    time_input_shape = self.g.get_shape(scan_node.input[1])
    time_input_dtype = self.g.get_dtype(scan_node.input[1])
    log.debug("_create_scan_node - set scan state_output shape for %s[%s]:%s",
              scan_node.name, index, time_input_shape)
    self.g.set_shape(scan_node.output[index], time_input_shape)
    self.g.set_dtype(scan_node.output[index], time_input_dtype)
    index += 1

    # for other state vars
    state_input_shape = self.g.get_shape(scan_node.input[2])
    state_input_dtype = self.g.get_dtype(scan_node.input[2])
    for i in range(len(scan_props.loop_state_outputs) - 1):
        log.debug("_create_scan_node - set scan state_output shape for %s[%s]:%s",
                  scan_node.name, index, state_input_shape)
        self.g.set_shape(scan_node.output[index], state_input_shape)
        self.g.set_dtype(scan_node.output[index], state_input_dtype)
        index += 1

    last_scan_input_shape = self.g.get_shape(scan_node.input[-1])
    batch = last_scan_input_shape[0]  # should be 1
    time = last_scan_input_shape[1]
    for i in range(len(scan_props.loop_scan_outputs)):
        scan_out_dtype = self.g.get_dtype(scan_props.loop_scan_outputs[i])
        output_shape = self.g.get_shape(scan_props.loop_scan_outputs[i])
        scan_output_shape = [batch, time] + output_shape
        log.debug("scan output [%s] has shape %s, batch: %s, time: %s, cell output shape: %s",
                  scan_props.loop_scan_outputs[i], scan_output_shape, batch, time, output_shape)
        log.debug("_create_scan_node - set scan scan_output shape for %s[%s]:%s",
                  scan_node.name, index, scan_output_shape)
        self.g.set_shape(scan_node.output[index], scan_output_shape)
        self.g.set_dtype(scan_node.output[index], scan_out_dtype)
        index += 1

    return scan_node
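# Minimal illustration (with made-up numbers, not the rewriter's real values) of how the
# Scan output shape is assembled above: per-step cell outputs are stacked behind the fake
# batch dim (1) and the time dim.
batch, time = 1, 5                 # batch is the fake size of 1 used by the Scan body
cell_output_shape = [3]            # hypothetical per-step output shape of the loop body
scan_output_shape = [batch, time] + cell_output_shape
assert scan_output_shape == [1, 5, 3]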
def slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm, lstm_output_index, all_nodes, to_remove):
    fw_consumers = g.find_output_consumers(lstm_fw.output[lstm_output_index])
    bw_consumers = g.find_output_consumers(lstm_bw.output[lstm_output_index])

    if lstm_output_index == 0:
        axis = 1
        # remove reverse op for lstm_bw
        reverse_nodes = get_reverse_nodes_after_y_output(g, lstm_bw)
        if not reverse_nodes:
            raise ValueError("y_output is not followed by a reverse node, which should not happen")
        for r_op in reverse_nodes:
            log.debug("remove reverse op %s", r_op.name)
            g.replace_all_inputs(all_nodes, r_op.output[0], r_op.input[0])
            to_remove.append(r_op.name)
    elif lstm_output_index in [1, 2]:
        axis = 0
    else:
        raise ValueError("LSTM should only have 3 outputs")

    if fw_consumers:
        attr = {"axes": [axis], "starts": [0], "ends": [1]}
        slice_node_fw = make_onnx_node(g, "Slice", [bi_lstm.output[lstm_output_index]], attr)
        all_nodes.append(slice_node_fw)
        g.replace_all_inputs(fw_consumers, lstm_fw.output[lstm_output_index], slice_node_fw.output[0])

    if bw_consumers:
        attr = {"axes": [axis], "starts": [1], "ends": [2]}
        slice_node_bw = make_onnx_node(g, "Slice", [bi_lstm.output[lstm_output_index]], attr)
        all_nodes.append(slice_node_bw)
        g.replace_all_inputs(bw_consumers, lstm_bw.output[lstm_output_index], slice_node_bw.output[0])
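# Not part of the rewriter: a numpy sketch of the Slice operations above on the merged
# bidirectional outputs. All sizes are made up for illustration.
import numpy as np

seq, batch, hidden = 5, 2, 3
y = np.random.rand(seq, 2, batch, hidden).astype(np.float32)    # bi-LSTM Y: num_directions on axis 1
y_h = np.random.rand(2, batch, hidden).astype(np.float32)       # bi-LSTM Y_h: num_directions on axis 0

# Output 0 is sliced on axis 1; outputs 1 and 2 are sliced on axis 0, matching the code above.
y_fw, y_bw = y[:, 0:1, :, :], y[:, 1:2, :, :]
y_h_fw, y_h_bw = y_h[0:1, :, :], y_h[1:2, :, :]
assert y_fw.shape == (seq, 1, batch, hidden) and y_h_fw.shape == (1, batch, hidden)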
def _convert_timemajor_transpose(self, node):
    if not check_is_timemajor_transpose(node):
        log.debug("not found timemajor transpose")
        return None

    log.debug("found timemajor transpose")
    attr = {"perm": np.array([1, 0, 2], dtype=np.int64)}
    new_trans = make_onnx_node(self.g, "Transpose", [node.input[0]], attr)
    self.g.copy_shape(node.output[0], new_trans.output[0])
    self.g.replace_all_inputs(self.g.get_nodes(), node.output[0], new_trans.output[0])
    return new_trans
def _connect_lstm_ych_to_graph(self, lstm_node, exit_node, rnn_props):
    # in tf, the concat of y_c and y_h output shape is [batch, hidden * 2]
    # in onnx, y_c/y_h output shape is [num_directions, batch, hidden]
    concat = make_onnx_node(self.g, "Concat",
                            [lstm_node.output[2], lstm_node.output[1]], attr={"axis": 2})
    yc_shape = self.g.get_shape(lstm_node.output[2])
    self.g.set_shape(concat.output[0], [yc_shape[0], yc_shape[1], yc_shape[2] * 2])

    squeeze_node = make_onnx_node(self.g, "Squeeze", [concat.output[0]], attr={"axes": [0]})
    concat_shape = self.g.get_shape(concat.output[0])
    self.g.set_shape(squeeze_node.output[0], [concat_shape[1], concat_shape[2]])

    self.all_nodes.extend([concat, squeeze_node])
    self.g.replace_all_inputs(self.all_nodes, exit_node.output[0], squeeze_node.output[0])
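# Not part of the rewriter: a numpy sketch of the Concat + Squeeze above, which rebuilds
# TF's fused [batch, hidden * 2] c/h state from ONNX's y_c and y_h. Sizes are made up.
import numpy as np

batch, hidden = 2, 3
y_h = np.random.rand(1, batch, hidden).astype(np.float32)   # [num_directions, batch, hidden]
y_c = np.random.rand(1, batch, hidden).astype(np.float32)

ych = np.squeeze(np.concatenate((y_c, y_h), axis=2), axis=0)
assert ych.shape == (batch, hidden * 2)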
def _create_squeeze_node(self, target_name, input_id):
    squeeze_node = make_onnx_node(self.g, "Squeeze", [input_id], attr={"axes": [0]},
                                  skip_conversion=True, op_name_scope=target_name)
    input_shape = self.g.get_shape(input_id)
    if input_shape is None:
        raise ValueError(input_id + " is none")
    input_shape = list(input_shape)[1:]
    self.g.set_shape(squeeze_node.output[0], input_shape)
    self.g.set_dtype(squeeze_node.output[0], self.g.get_dtype(input_id))
    return squeeze_node
def _process_single_init_node(g, fw_init_input_id, bw_init_input_id, to_append):
    fw_init_is_const, init_fw_val = check_const(g, fw_init_input_id)
    bw_init_is_const, init_bw_val = check_const(g, bw_init_input_id)
    if fw_init_is_const and bw_init_is_const:
        initial_val = np.concatenate((init_fw_val, init_bw_val), axis=0)
        init_name = utils.make_name("initial")
        init_node = g.make_const(init_name, initial_val, skip_conversion=True)
    else:
        attr = {"axis": 0}
        init_node = make_onnx_node(g, "Concat", [fw_init_input_id, bw_init_input_id], attr)
        to_append.append(init_node)

    return init_node
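# Not part of the rewriter: a numpy sketch of the const branch above, where the per-direction
# initial states are folded into one constant at conversion time. Sizes are made up.
import numpy as np

batch, hidden = 2, 3
init_fw = np.zeros((1, batch, hidden), dtype=np.float32)    # forward initial state
init_bw = np.ones((1, batch, hidden), dtype=np.float32)     # backward initial state

initial = np.concatenate((init_fw, init_bw), axis=0)
assert initial.shape == (2, batch, hidden)                   # [num_directions, batch, hidden]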
def _process_init_nodes(self, initializer_input_id, rnn_props):
    # copied from lstm_rewriter
    # todo: remove this once the Fill op is supported
    fill_ch_init_node = self._workaround_fill_ch_init_node(initializer_input_id, rnn_props)
    if fill_ch_init_node:
        return fill_ch_init_node.output[0]

    node = self.g.get_node_by_name(initializer_input_id)
    if node.is_const():
        val = node.get_tensor_value()
        initial_name = utils.make_name("Const")
        new_val = np.expand_dims(val, axis=0)
        const_node = self.g.make_const(initial_name, new_val)
        return const_node.output[0]

    unsqueeze_node = make_onnx_node(self.g, "Unsqueeze", [initializer_input_id], attr={"axes": [0]})
    self.g.replace_all_inputs(self.g.get_nodes(), initializer_input_id, unsqueeze_node.output[0])
    self.all_nodes.append(unsqueeze_node)
    return unsqueeze_node.output[0]
def process_bilstm(g, bi_lstms):
    for fw, bw in bi_lstms:
        input_id = fw[0]
        log.debug("=========================")
        log.debug("start handling potential bidirectional lstm %s", input_id)
        lstm_fw = fw[1]
        lstm_bw = bw[1]

        w_fw = get_np_val_for_const(g, lstm_fw, 1)
        w_bw = get_np_val_for_const(g, lstm_bw, 1)
        r_fw = get_np_val_for_const(g, lstm_fw, 2)
        r_bw = get_np_val_for_const(g, lstm_bw, 2)
        b_fw = get_np_val_for_const(g, lstm_fw, 3)
        b_bw = get_np_val_for_const(g, lstm_bw, 3)
        W = np.concatenate((w_fw, w_bw), axis=0)
        R = np.concatenate((r_fw, r_bw), axis=0)
        B = np.concatenate((b_fw, b_bw), axis=0)

        all_nodes = g.get_nodes()
        if len(lstm_fw.inputs) == len(lstm_bw.inputs):
            if len(lstm_fw.inputs) > 4:
                h_node, c_node = process_ch_init_nodes(g, lstm_fw, lstm_bw, all_nodes)
        else:
            log.error("fw and bw lstm input counts are not consistent, stop")
            continue

        # create node
        w_name = utils.make_name("W")
        w_node = g.make_const(w_name, W, skip_conversion=True)

        r_name = utils.make_name("R")
        r_node = g.make_const(r_name, R, skip_conversion=True)

        b_name = utils.make_name("B")
        b_node = g.make_const(b_name, B, skip_conversion=True)

        lstm_inputs = [lstm_fw.input[0], w_node.output[0], r_node.output[0], b_node.output[0]]
        if len(lstm_fw.inputs) > 4:
            lstm_inputs.extend([lstm_fw.input[4], h_node.output[0], c_node.output[0]])

        direction = "bidirectional"
        if lstm_fw.get_attr("hidden_size").i == lstm_bw.get_attr("hidden_size").i:
            hidden_size = lstm_fw.get_attr("hidden_size").i
        else:
            log.error("fw and bw have different hidden_size, skip")
            continue

        attr = {"direction": direction, "hidden_size": hidden_size}
        bi_lstm_node = make_onnx_node(g, "LSTM", lstm_inputs, attr=attr, output_count=3)
        all_nodes.append(bi_lstm_node)
        log.debug("processing output nodes")

        to_remove = [lstm_fw.name, lstm_fw.input[1], lstm_fw.input[2], lstm_fw.input[3],
                     lstm_bw.name, lstm_bw.input[1], lstm_bw.input[2], lstm_bw.input[3]]
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm_node, 0, all_nodes, to_remove)
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm_node, 1, all_nodes, to_remove)
        slice_bilstm_for_original_lstm_consumers(g, lstm_fw, lstm_bw, bi_lstm_node, 2, all_nodes, to_remove)

        lstm_bw_old_x = lstm_bw.input[0]
        new_nodes = []
        for n in all_nodes:
            if n.name not in to_remove:
                new_nodes.append(n)
        g.set_nodes(new_nodes)

        old_x_consumers = g.find_output_consumers(lstm_bw_old_x)
        # the transpose/reverse here must be followed by an LSTM if it is still useful.
        # this is guaranteed by dynamic_rnn logic.
        old_x_has_lstm_as_consumer = [n for n in old_x_consumers if n.type == "LSTM"]
        if not old_x_has_lstm_as_consumer:
            log.debug("plan to remove useless reverse op in bw")
            reverse_node = g.get_node_by_name(lstm_bw_old_x)

            if reverse_node.type == "Transpose":
                reverse_node = reverse_node.inputs[0]

            g.replace_all_inputs(g.get_nodes(), reverse_node.output[0], reverse_node.input[0])
            new_nodes = g.get_nodes()
            new_nodes.remove(reverse_node)
            g.set_nodes(new_nodes)
        else:
            raise ValueError("Reverse is still used by LSTM as input, cannot remove")

    g.update_proto()
    return g.get_nodes()
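# Not part of the rewriter: a numpy sketch of the weight merging above. A single-direction
# ONNX LSTM carries W of shape [1, 4*hidden, input], so stacking fw/bw along axis 0 yields
# the num_directions = 2 layout the bidirectional node expects. Sizes are made up.
import numpy as np

input_size, hidden = 4, 3
w_fw = np.random.rand(1, 4 * hidden, input_size).astype(np.float32)
w_bw = np.random.rand(1, 4 * hidden, input_size).astype(np.float32)

W = np.concatenate((w_fw, w_bw), axis=0)
assert W.shape == (2, 4 * hidden, input_size)    # R and B are merged the same way on axis 0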
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
    nodes_to_add = []
    shape_node = make_onnx_node(self.g, "Shape", [input_id])
    nodes_to_add.append(shape_node)
    inferred_shape = self.g.get_shape(input_id)
    if handle_output is True:
        # handle output:
        # if the required dim values don't contain more than one -1,
        # just use a const for Reshape's shape input.
        if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
            new_shape_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array(inferred_shape[1:], dtype=np.int64))
        else:
            # otherwise, get the dims dynamically, e.g. remove the fake batch size (e.g. 1)
            # from [1, time, real-batch, ...]
            origin_shape_node = make_onnx_node(self.g, "Cast", [shape_node.output[0]],
                                               {"to": onnx_pb.TensorProto.FLOAT})
            nodes_to_add.append(origin_shape_node)

            sliced_shape_node = make_onnx_node(self.g, "Slice", [origin_shape_node.output[0]],
                                               {"axes": [0], "starts": [1], "ends": [sys.maxsize]})
            nodes_to_add.append(sliced_shape_node)

            new_shape_node = make_onnx_node(self.g, "Cast", [sliced_shape_node.output[0]],
                                            {"to": onnx_pb.TensorProto.INT64})
            nodes_to_add.append(new_shape_node)

        new_shape = inferred_shape[1:]
    else:
        # handle input:
        if inferred_shape is not None and inferred_shape.count(-1) <= 1:
            new_shape_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array([1] + inferred_shape, dtype=np.int64))
        else:
            # add a fake batch size: 1
            fake_batch_size_node = self.g.make_const(
                utils.make_name(target_name + "_target_shape"),
                np.array([1], dtype=np.int64))
            new_shape_node = make_onnx_node(
                self.g, "Concat",
                [fake_batch_size_node.output[0], shape_node.output[0]],
                {"axis": 0})
            nodes_to_add.append(new_shape_node)

        new_shape = [1] + inferred_shape

    reshape_node = make_onnx_node(self.g, "Reshape", [input_id, new_shape_node.output[0]],
                                  skip_conversion=True, op_name_scope=target_name)
    nodes_to_add.append(reshape_node)
    self.g.set_shape(reshape_node.output[0], new_shape)
    self.g.set_dtype(reshape_node.output[0], self.g.get_dtype(input_id))
    log.debug("create Reshape for scan output %s, with output shape %s",
              reshape_node.output[0], new_shape)
    return nodes_to_add
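# Not part of the rewriter: a numpy sketch of the two Reshape directions above - prepending a
# fake batch dim of 1 before feeding Scan, and dropping it again on the output side. Sizes
# are made up for illustration.
import numpy as np

time, batch, hidden = 5, 2, 3
scan_tensor = np.random.rand(time, batch, hidden).astype(np.float32)

# handle_output=False: add the fake batch dim.
as_scan_seq = scan_tensor.reshape([1] + list(scan_tensor.shape))     # [1, time, batch, hidden]
# handle_output=True: remove the fake batch dim again.
back = as_scan_seq.reshape(list(as_scan_seq.shape)[1:])              # [time, batch, hidden]
assert back.shape == scan_tensor.shape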
def process_bigru(g, bi_grus):
    for fw, bw in bi_grus:
        input_id = fw[0]
        log.debug("=========================")
        log.debug("start handling potential bidirectional gru %s", input_id)
        gru_fw = fw[1]
        gru_bw = bw[1]

        w_fw = get_np_val_for_const(g, gru_fw, 1)
        w_bw = get_np_val_for_const(g, gru_bw, 1)
        r_fw = get_np_val_for_const(g, gru_fw, 2)
        r_bw = get_np_val_for_const(g, gru_bw, 2)
        b_fw = get_np_val_for_const(g, gru_fw, 3)
        b_bw = get_np_val_for_const(g, gru_bw, 3)
        W = np.concatenate((w_fw, w_bw), axis=0)
        R = np.concatenate((r_fw, r_bw), axis=0)
        B = np.concatenate((b_fw, b_bw), axis=0)

        all_nodes = g.get_nodes()
        if len(gru_fw.inputs) == len(gru_bw.inputs):
            if len(gru_fw.inputs) > 4:
                initializer_node = process_init_nodes(g, gru_fw, gru_bw, all_nodes)
        else:
            log.error("fw and bw gru input counts are not consistent, stop")
            continue

        # create node
        w_name = utils.make_name("W")
        w_node = g.make_const(w_name, W, skip_conversion=True)

        r_name = utils.make_name("R")
        r_node = g.make_const(r_name, R, skip_conversion=True)

        b_name = utils.make_name("B")
        b_node = g.make_const(b_name, B, skip_conversion=True)

        gru_inputs = [gru_fw.input[0], w_node.output[0], r_node.output[0], b_node.output[0]]
        if len(gru_fw.inputs) > 4:
            gru_inputs.extend([gru_fw.input[4], initializer_node.output[0]])

        direction = "bidirectional"
        if gru_fw.get_attr("hidden_size").i == gru_bw.get_attr("hidden_size").i:
            hidden_size = gru_fw.get_attr("hidden_size").i
        else:
            log.error("fw and bw have different hidden_size, skip")
            continue

        # activations have to be taken care of;
        # attr here is a proto, and make_onnx_node needs a dict
        activations = [act.decode("utf-8") for act in gru_fw.get_attr("activations").strings]
        activations += [act.decode("utf-8") for act in gru_bw.get_attr("activations").strings]

        attr = {"direction": direction, "hidden_size": hidden_size, "activations": activations}
        bi_gru_node = make_onnx_node(g, "GRU", gru_inputs, attr=attr, output_count=2)
        all_nodes.append(bi_gru_node)
        log.debug("processing output nodes")

        to_remove = [gru_fw.name, gru_fw.input[1], gru_fw.input[2], gru_fw.input[3],
                     gru_bw.name, gru_bw.input[1], gru_bw.input[2], gru_bw.input[3]]
        slice_bilstm_for_original_lstm_consumers(g, gru_fw, gru_bw, bi_gru_node, 0, all_nodes, to_remove)
        slice_bilstm_for_original_lstm_consumers(g, gru_fw, gru_bw, bi_gru_node, 1, all_nodes, to_remove)

        gru_bw_old_x = gru_bw.input[0]
        new_nodes = []
        for n in all_nodes:
            if n.name not in to_remove:
                new_nodes.append(n)
        g.set_nodes(new_nodes)

        old_x_consumers = g.find_output_consumers(gru_bw_old_x)
        # the transpose/reverse here must be followed by a GRU if it is still useful.
        # this is guaranteed by dynamic_rnn logic.
        old_x_has_gru_as_consumer = [n for n in old_x_consumers if n.type == "GRU"]
        if not old_x_has_gru_as_consumer:
            log.debug("plan to remove useless reverse op in bw")
            reverse_node = g.get_node_by_name(gru_bw_old_x)

            if reverse_node.type == "Transpose":
                reverse_node = reverse_node.inputs[0]

            g.replace_all_inputs(g.get_nodes(), reverse_node.output[0], reverse_node.input[0])
            new_nodes = g.get_nodes()
            new_nodes.remove(reverse_node)
            g.set_nodes(new_nodes)
        else:
            raise ValueError("Reverse is still used by GRU as input, cannot remove")

    g.update_proto()
    return g.get_nodes()