def _connect_lstm_ych_to_graph(self, context, i):
    # in tf, concat of y_c and y_h output shape is: [batch, hidden * 2]
    # in onnx, y_c/y_h output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct_ht" + str(i)].exit_output
    lstm_node = context.rnn_node[i]
    yc_shape = self.g.get_shape(lstm_node.output[2])
    concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
    concat = self.g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                              attr={"axis": 2}, shapes=[concat_output_shape],
                              dtypes=[self.g.get_dtype(lstm_node.output[2])])

    squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
    squeeze_node = gb.make_squeeze({'data': concat.output[0], "axes": [0]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(concat.output[0])],
                                   return_node=True)

    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
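# A minimal numpy sketch (illustrative only, not part of the converter) of the
# shape mapping handled above: ONNX emits y_c/y_h as [num_directions, batch, hidden];
# Concat on axis 2 followed by Squeeze on axis 0 recovers tf's [batch, hidden * 2] ct_ht.
def _sketch_ych_shape_mapping():
    import numpy as np
    batch, hidden = 2, 3
    y_c = np.zeros((1, batch, hidden))           # ONNX layout, num_directions == 1
    y_h = np.ones((1, batch, hidden))
    concat = np.concatenate([y_c, y_h], axis=2)  # [1, batch, hidden * 2]
    ct_ht = np.squeeze(concat, axis=0)           # tf layout: [batch, hidden * 2]
    assert ct_ht.shape == (batch, hidden * 2)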
def connect_unit_rnn_output_to_graph(self, context):
    outputs = context.loop_properties.scan_outputs_exits
    if not outputs:
        logger.debug("no one consumes the rnn output")
        return

    gather_output_id = outputs[0].id
    logger.debug("found output for rnn: %s", gather_output_id)

    # in tf batch major mode, output shape is: [batch, time, hidden]
    # in time major mode, output shape is: [time, batch, hidden]
    # in onnx, output shape is: [time, num_directions, batch, hidden]
    rnn_node = context.rnn_node
    output_id = rnn_node.output[0]
    rnn_output_shape = self.g.get_shape(output_id)
    squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
    gb = GraphBuilder(self.g)
    squeeze_node = gb.make_squeeze({'data': output_id, "axes": [1]},
                                   shapes=[squeeze_output_shape],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)
    self.g.replace_all_inputs(gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
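# Illustrative numpy sketch (not converter code): squeezing axis 1 drops the
# num_directions dimension, turning ONNX's [time, num_directions, batch, hidden]
# output into the time-major tf layout [time, batch, hidden].
def _sketch_rnn_output_shape_mapping():
    import numpy as np
    time, batch, hidden = 5, 2, 3
    y = np.zeros((time, 1, batch, hidden))  # ONNX output, num_directions == 1
    tf_out = np.squeeze(y, axis=1)          # time-major tf output
    assert tf_out.shape == (time, batch, hidden)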
def create_rnn_node(self, context):
    gb = GraphBuilder(self.g)
    rnn_nodes = list()
    outputs = context.loop_properties.scan_outputs_exits
    logger.debug("number of rnn node outputs: %s", len(outputs))

    for i in range(self.num_lstm_layers):
        logger.debug("creating rnn node for layer: %s", i)
        rnn_nodes.append(self.create_single_rnn_node(context, i))
        output_id = rnn_nodes[i].output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]]
        squeeze_node = gb.make_squeeze({"data": output_id, "axes": [1]},
                                       shapes=[squeeze_output_shape],
                                       dtypes=[self.g.get_dtype(output_id)],
                                       return_node=True)
        if i + 1 < self.num_lstm_layers:
            logger.debug("setting input for layer: %s", i + 1)
            context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
    return rnn_nodes
def version_13(cls, ctx, node, **kwargs):
    ctx.ta_reads.append(node.input[0])
    node.type = "Gather"
    ctx.replace_inputs(node, [node.input[0], node.input[1]])

    g = GraphBuilder(ctx)

    usq_node = g.make_unsqueeze({"axes": [0], 'name': node.child_name(), 'data': node.input[1]},
                                return_node=True)
    ctx.insert_node_on_output(usq_node)

    sq_node = g.make_squeeze({"axes": [0], 'name': node.child_name(), 'data': node.output[0]},
                             return_node=True)
    ctx.insert_node_on_output(sq_node)
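# Illustrative numpy sketch (not converter code) of why the Gather is wrapped in
# Unsqueeze/Squeeze: the scalar read index is lifted to rank 1 so Gather picks a
# single [1, ...] slice, and the final squeeze drops that axis to match the
# element shape of the original TensorArray read.
def _sketch_gather_read():
    import numpy as np
    ta = np.arange(12).reshape(4, 3)     # 4 time steps, each of shape [3]
    index = np.array(2)                  # scalar read index
    idx = np.expand_dims(index, 0)       # Unsqueeze(axes=[0]) -> shape [1]
    gathered = np.take(ta, idx, axis=0)  # Gather -> shape [1, 3]
    read = np.squeeze(gathered, axis=0)  # Squeeze(axes=[0]) -> shape [3]
    assert np.array_equal(read, ta[2])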
def _connect_lstm_yc_to_graph(self, context, i):
    # in tf, y_c output shape is: [batch, hidden]
    # in onnx, output shape is: [num_directions, batch, hidden]
    gb = GraphBuilder(self.g)
    exit_output = context.state_variables["ct" + str(i)].exit_output
    output_id = context.rnn_node[i].output[2]
    lstm_yc_shape = self.g.get_shape(output_id)
    squeeze_node = gb.make_squeeze({"data": output_id, "axes": [0]},
                                   shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
                                   dtypes=[self.g.get_dtype(output_id)],
                                   return_node=True)
    self.g.replace_all_inputs(exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
def version_10(cls, ctx, node, **kwargs):
    x = node.input[0]
    x_shape = ctx.get_shape(x)
    h = node.input[1]
    h_shape = ctx.get_shape(h)
    p = node.input[3]
    utils.make_sure(node.attr["rnn_mode"].s == b"gru",
                    "rnn modes other than gru are not supported yet")
    utils.make_sure(node.attr["dropout"].f == 0, "dropout not supported yet")
    utils.make_sure(node.attr["input_mode"].s == b"linear_input",
                    "input mode must be linear input")
    num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
    num_layers = int(h_shape[0] / num_dirs)
    num_units = hidden_size = h_shape[2]
    input_size = x_shape[2]
    w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
    w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
    r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
    r_shape_const = ctx.make_const(utils.make_name("r_shape"), np.array(r_shape, dtype=np.int64))
    b_shape = [num_layers * num_dirs, 6 * hidden_size]
    b_shape_const = ctx.make_const(utils.make_name("b_shape"), np.array(b_shape, dtype=np.int64))
    zero_const = ctx.make_const(utils.make_name("zero"), np.array([0], dtype=np.int64))
    w_end = np.prod(w_shape)
    w_end_const = ctx.make_const(utils.make_name("w_end"), np.array([w_end], dtype=np.int64))
    r_end = w_end + np.prod(r_shape)
    r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
    b_end = r_end + np.prod(b_shape)
    b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))

    def name(nm):
        return node.name + "_" + nm

    ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
    rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
    bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
    hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
    yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]
    w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
    r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
    b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
    w = utils.make_name('W')
    r = utils.make_name('R')
    b = utils.make_name('B')
    ctx.make_node('Reshape', [w_flattened.output[0], w_shape_const.output[0]], outputs=[w])
    ctx.make_node('Reshape', [r_flattened.output[0], r_shape_const.output[0]], outputs=[r])
    ctx.make_node('Reshape', [b_flattened.output[0], b_shape_const.output[0]], outputs=[b])
    ctx.make_node('Split', [w], outputs=ws)
    ctx.make_node('Split', [r], outputs=rs)
    ctx.make_node('Split', [b], outputs=bs)
    ctx.make_node('Split', [h], outputs=hs)
    builder = GraphBuilder(ctx)
    xnf = xnb = x
    for i in range(num_layers):
        suffix = '_' + str(i * num_dirs)
        ctx.make_node('GRU',
                      [xnf, name('W' + suffix), name('R' + suffix), name('B' + suffix),
                       '', name('H' + suffix)],
                      outputs=[name('Y' + suffix), name('YH' + suffix)],
                      attr={'direction': 'forward', 'hidden_size': num_units})
        xnf = name(x + suffix)
        builder.make_squeeze({'data': name('Y' + suffix), 'outputs': [xnf], 'axes': [1]})
        if num_dirs == 2:
            suffix = '_' + str(i * 2 + 1)
            ctx.make_node('GRU',
                          [xnb, name('W' + suffix), name('R' + suffix), name('B' + suffix),
                           '', name('H' + suffix)],
                          outputs=[name('Y' + suffix), name('YH' + suffix)],
                          attr={'direction': 'reverse', 'hidden_size': num_units})
            xnb = name(x + suffix)
            builder.make_squeeze({'data': name('Y' + suffix), 'outputs': [xnb], 'axes': [1]})
    ctx.remove_node(node.name)
    if num_dirs == 2:
        ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
    else:
        ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
    ctx.make_node('Concat', yhs, outputs=[node.output[1]], attr={'axis': 0})
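# Illustrative numpy sketch (not converter code) of the flat cudnn parameter
# layout this handler assumes: the blob p packs W, then R, then B back to back,
# so the Slice offsets are just running products of the per-tensor shapes.
def _sketch_cudnn_param_offsets():
    import numpy as np
    num_layers, num_dirs, hidden_size, input_size = 1, 1, 4, 5
    w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
    r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
    b_shape = [num_layers * num_dirs, 6 * hidden_size]
    w_end = np.prod(w_shape)
    r_end = w_end + np.prod(r_shape)
    b_end = r_end + np.prod(b_shape)
    p = np.arange(b_end, dtype=np.float32)  # stand-in for the packed blob
    w = p[0:w_end].reshape(w_shape)         # Slice + Reshape, as above
    r = p[w_end:r_end].reshape(r_shape)
    b = p[r_end:b_end].reshape(b_shape)
    assert w.size + r.size + b.size == p.size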
def rewrite(self, context):
    logger.debug("enter rewrite function")
    loop_node = None
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # create a dummy loop to calculate the init condition
        init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

        ## create Loop body graph with existing nodes
        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requires output shape [{}] to exist".format(out_tensor_value_info.id)
            )
            out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

        # create loop body graph inputs
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we create a fake one
                # with the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            gb = GraphBuilder(loop_body_g)
            index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
            gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
            loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])  # ops=loop_body_g.get_nodes()

        ## create Loop node
        branches = {"body": loop_body_g}
        loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite succeeded")
        return REWRITER_RESULT.OK
    except Exception as ex:
        tb = traceback.format_exc()
        logger.error("loop rewrite failed, due to exception: %s, details: %s", ex, tb)
        return REWRITER_RESULT.FAIL
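# Rough Python emulation (illustrative only) of the ONNX Loop body contract the
# rewriter builds above: the body graph receives the iteration counter, the
# condition, and one input per loop-carried state variable, and returns the new
# condition followed by the updated states.
def _sketch_loop_semantics():
    import numpy as np

    def run_loop(trip_count, cond, states, body):
        i = 0
        while i < trip_count and cond:
            cond, *states = body(np.int64(i), cond, *states)
            i += 1
        return states

    # body: (i, cond, *states) -> (cond, *new_states)
    def body(i, cond, acc):
        return cond, acc + i

    assert run_loop(5, True, [np.int64(0)], body) == [np.int64(10)]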