def _create_loop_node(self, context, loop_props):
    """Create the ONNX ``Loop`` node for the loop described by ``loop_props``.

    The node's outputs are taken from ``loop_props.state_outputs_exits``
    followed by ``loop_props.scan_outputs_exits``. An exit that has an id
    keeps that id, shape and dtype, and the graph node currently producing
    it is removed. An exit without an id gets a freshly generated
    placeholder name, shape ``[-1]`` and dtype ``None``.

    Args:
        context: rewriter context; not read by this method.
        loop_props: loop properties providing ``state_outputs_exits``,
            ``scan_outputs_exits`` and ``state_inputs_initial_values``.

    Returns:
        The node returned by ``self.g.make_node`` for the ``Loop`` op.
    """
    loop_outputs = []
    loop_output_shapes = []
    loop_output_dtypes = []
    for tensor_value_info in loop_props.state_outputs_exits + loop_props.scan_outputs_exits:
        if tensor_value_info.id:
            loop_outputs.append(tensor_value_info.id)
            loop_output_shapes.append(tensor_value_info.shape)
            loop_output_dtypes.append(tensor_value_info.dtype)

            # Remove the current producer of this output id; presumably so the
            # Loop node created below can own the id - confirm with callers.
            n = self.g.get_node_by_output(tensor_value_info.id)
            self.g.remove_node(n.name)
        else:
            loop_outputs.append(utils.make_name("unused_loop_output_"))
            loop_output_shapes.append([-1])
            loop_output_dtypes.append(None)

    # trip count and cond are not used, giving them values just because of a bug
    # (https://github.com/Microsoft/onnxruntime/issues/255) in onnxruntime.
    trip_cnt = self.g.make_const(utils.make_name("trip_count"),
                                 np.array(sys.maxsize, dtype=np.int64))
    # np.bool_ instead of the np.bool alias: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, while np.bool_ exists in every release.
    cond = self.g.make_const(utils.make_name("cond"), np.array(True, dtype=np.bool_))

    # ONNX Loop supports state inputs only.
    loop_inputs = [trip_cnt.output[0]] + [cond.output[0]] + loop_props.state_inputs_initial_values
    loop_node = self.g.make_node("Loop", loop_inputs,
                                 outputs=loop_outputs,
                                 op_name_scope="generic_loop",
                                 shapes=loop_output_shapes,
                                 dtypes=loop_output_dtypes,
                                 skip_conversion=False)
    return loop_node
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
    """Insert a ``Reshape`` that adds or removes the leading dim of a scan tensor.

    When ``handle_output`` is ``True`` the first dimension of ``input_id`` is
    dropped (target shape is ``shape[1:]``); otherwise a leading dimension of
    size 1 is prepended (target shape is ``[1] + shape``).

    The Reshape's shape input is a constant when the inferred shape is known
    and the target contains at most one ``-1``; otherwise it is computed at
    runtime from a ``Shape`` node.

    Args:
        target_name: prefix for the generated constant name and the name
            scope of the Reshape node.
        input_id: output id of the tensor to reshape.
        handle_output: ``True`` to strip the leading dimension, anything
            else to prepend one.

    Returns:
        List of the nodes created through ``make_node`` (the ``Shape`` node,
        any dynamic shape-computation nodes and the ``Reshape`` node).
        Constants created through ``make_const`` are not included.
    """
    nodes_to_add = []
    # NOTE(review): the Shape node is created and returned even when the
    # constant path below never consumes it; kept to preserve the node list.
    shape_node = self.g.make_node("Shape", [input_id])
    nodes_to_add.append(shape_node)
    inferred_shape = self.g.get_shape(input_id)
    if handle_output is True:
        # handle output:
        # if the required dim values don't contain more than one -1,
        # just use a const for Reshape's shape input.
        if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
            new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                               np.array(inferred_shape[1:], dtype=np.int64))
        else:
            # otherwise, get the dims dynamically, e.g. remove the fake batch size (e.g. 1)
            # from [1, time, real-batch, ...].
            # The shape is cast to float for the Slice and back to int64 afterwards;
            # presumably Slice could not take int64 input in the targeted opset - TODO confirm.
            origin_shape_node = self.g.make_node("Cast", [shape_node.output[0]],
                                                 {"to": onnx_pb.TensorProto.FLOAT})
            nodes_to_add.append(origin_shape_node)

            sliced_shape_node = self.g.make_node("Slice", [origin_shape_node.output[0]],
                                                 {"axes": [0], "starts": [1], "ends": [sys.maxsize]})
            nodes_to_add.append(sliced_shape_node)

            new_shape_node = self.g.make_node("Cast", [sliced_shape_node.output[0]],
                                              {"to": onnx_pb.TensorProto.INT64})
            nodes_to_add.append(new_shape_node)

        # The dynamic path is also taken when the shape is unknown, so the static
        # output shape can only be derived when an inferred shape exists.
        new_shape = inferred_shape[1:] if inferred_shape is not None else None
    else:
        # handle input:
        if inferred_shape is not None and inferred_shape.count(-1) <= 1:
            new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                               np.array([1] + inferred_shape, dtype=np.int64))
        else:
            # add a fake batch size : 1
            fake_batch_size_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                     np.array([1,], dtype=np.int64))
            new_shape_node = self.g.make_node("Concat",
                                              [fake_batch_size_node.output[0], shape_node.output[0]],
                                              attr={"axis": 0})
            nodes_to_add.append(new_shape_node)

        new_shape = [1] + inferred_shape if inferred_shape is not None else None

    reshape_kwargs = {"dtypes": [self.g.get_dtype(input_id)], "op_name_scope": target_name}
    if new_shape is not None:
        # Only record a static output shape when one is actually known.
        reshape_kwargs["shapes"] = [new_shape]
    reshape_node = self.g.make_node("Reshape", [input_id, new_shape_node.output[0]], **reshape_kwargs)
    nodes_to_add.append(reshape_node)

    log.debug("create Reshape for scan output %s, with output shape %s",
              reshape_node.output[0], new_shape)
    return nodes_to_add
def _create_loop_node(self, context, loop_props):
    """Create the ONNX ``Loop`` node for the loop described by ``loop_props``.

    Outputs are named after ``loop_props.state_outputs_exits`` followed by
    ``loop_props.scan_outputs_exits``. An exit with an id reuses that id
    (e.g. Exit_XXX), so no shape has to be set on the new output; an exit
    without an id gets a freshly generated placeholder name.

    Args:
        context: rewriter context; not read by this method.
        loop_props: loop properties providing ``state_outputs_exits``,
            ``scan_outputs_exits`` and ``state_inputs_initial_values``.

    Returns:
        The node returned by ``self.g.make_node`` for the ``Loop`` op.
    """
    # reuse original output connection id (e.g. Exit_XXX), so we don't need to set shape.
    loop_outputs = [
        tensor_value_info.id if tensor_value_info.id else utils.make_name("unused_loop_output_")
        for tensor_value_info in loop_props.state_outputs_exits + loop_props.scan_outputs_exits
    ]

    # trip count and cond are not used, giving them values just because of a bug
    # (https://github.com/Microsoft/onnxruntime/issues/255) in onnxruntime.
    trip_cnt = self.g.make_const(utils.make_name("trip_count"),
                                 np.array(sys.maxsize, dtype=np.int64))
    # np.bool_ instead of the np.bool alias: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, while np.bool_ exists in every release.
    cond = self.g.make_const(utils.make_name("cond"), np.array(True, dtype=np.bool_))

    # ONNX Loop supports state inputs only.
    loop_inputs = [trip_cnt.output[0]] + [cond.output[0]] + loop_props.state_inputs_initial_values
    loop_node = self.g.make_node("Loop", loop_inputs,
                                 outputs=loop_outputs,
                                 op_name_scope="generic_loop",
                                 skip_conversion=False)
    return loop_node
def rewrite(self):
    """Attach a body graph to every ``Scan`` node that has pending body-graph info.

    For each such node the body subgraph is located in ``self.g``, its
    Const/ConstV2 nodes are turned into body-graph initializers, graph
    inputs and Identity-wrapped outputs are declared, and the resulting
    graph is stored in the node's ``body`` attribute. Afterwards the nodes
    that moved into body graphs are dropped from the outer node list (or
    from the outer initializers).

    Returns:
        The outer graph's node list with the body-graph nodes removed.

    Raises:
        ValueError: if a body-graph node is neither in the outer node list
            nor an outer initializer.
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []
    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue

        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
        nodes_to_remove.extend(onnx_nodes)

        log.debug("start creating body graph for scan node %s ", scan_node.name)
        # Const nodes must not be passed as regular nodes; they become
        # initializers of the body graph instead.
        body_graph_initializers = {}
        for const_node in [n for n in onnx_nodes if n.type in ("Const", "ConstV2")]:
            const_output = const_node.output[0]
            body_graph_initializers[const_output] = self.g.initializers[const_output]
            onnx_nodes.remove(const_node)

        body_ops = [body_node.op for body_node in set(onnx_nodes)]
        body_g = Graph(body_ops, output_shapes=self.g._output_shapes, dtypes=self.g._dtypes)
        body_g._initializers = body_graph_initializers

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        # Inputs at or beyond this position are scan inputs; the ones before are states.
        first_scan_input = len(body_graph_meta.input_ids) - num_scan_inputs
        paired_inputs = zip(body_graph_meta.input_ids, body_graph_meta.initial_input_ids)
        for position, (input_name, init_input_id) in enumerate(paired_inputs):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is not None:
                loop_input_shape = list(shape)
            else:
                # Fall back to the shape of the initial value in the outer graph.
                shape = self.g.get_shape(init_input_id)
                loop_input_shape = list(shape)
                if position >= first_scan_input:
                    loop_input_shape = loop_input_shape[2:]  # delete [1, time,]

            value_info = helper.make_tensor_value_info(
                input_name, dtype, utils.make_onnx_shape(loop_input_shape))
            body_g.add_model_input(input_name, value_info)

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for output_id in body_graph_meta.output_ids:
            # Route each output through its own Identity node: the same id is
            # sometimes needed both as a state output and as a scan output, but
            # ONNX does not allow one output id to appear more than once.
            identity_name = utils.make_name("Identity")
            identity_output = utils.port_name(identity_name)
            identity_proto = helper.make_node("Identity", [output_id], [identity_output],
                                              name=identity_name)
            identity_node = Node(identity_proto, body_g)
            body_g.set_dtype(identity_output, body_g.get_dtype(output_id))
            body_g.copy_shape(output_id, identity_output)
            new_output_names.append(identity_output)
            temp_nodes.append(identity_node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.output_names = new_output_names
        scan_node.set_attr("body", body_g.make_graph("scan body graph"))

    # remove nodes in body graph from g
    for stale_node in set(nodes_to_remove):
        if stale_node in nodes:
            nodes.remove(stale_node)
            continue
        if not self.g.is_initializer(stale_node.output[0]):
            raise ValueError("error when removing nodes")
        del self.g.initializers[stale_node.output[0]]

    return nodes
def rewrite(self, context):
    """Rewrite the matched loop in ``context`` into an ONNX ``Loop`` node.

    The loop body graph is built from the cell-graph and condition-graph
    nodes, graph inputs are declared for the iteration counter, the
    condition and every state variable, tensor-array reads are replaced by
    Unsqueeze/Gather/Squeeze, and the body is attached to a new ``Loop``
    node as its ``body`` attribute.

    Args:
        context: rewriter context exposing ``loop_properties``,
            ``cell_graph`` and ``cond_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success; ``REWRITER_RESULT.FAIL`` if the
        loop node could not be created or any exception was raised (the
        exception is logged, not propagated).
    """
    log.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # todo(pengwa): we don't check the case where loop body won't be executed at all.

        # --- create Loop body graph with existing nodes ---

        # Make the condition graph consume the cell graph's outputs by
        # redirecting its inputs to the next-iteration values.
        for loop_var in cond_g_info.dependent_vars:
            self.g.replace_all_inputs(cond_g_info.nodes,
                                      loop_var.switch_true_identity_output.id,
                                      loop_var.next_iteration_input.id)

        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for body_output in body_outputs:
            body_output.shape = utils.create_vague_shape_like(body_output.shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

        # create loop body graph inputs
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for idx, state_input in enumerate(loop_props.state_inputs):
            if state_input.id is not None:
                value_source = state_input
            else:
                # The variable is not used in the body graph, so a fake input is
                # created with the type and shape of its corresponding output.
                value_source = loop_props.state_outputs[idx]
            dtype = value_source.dtype
            shape = value_source.shape

            input_name = state_input.id
            if input_name is None:
                input_name = utils.make_name("unused_state_input_")

            loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            unsqueezed_index = loop_body_g.make_node("Unsqueeze", [input_ta.index_input_id],
                                                     attr={"axes": [0]})
            gathered = loop_body_g.make_node("Gather",
                                             [input_ta.data_input_id, unsqueezed_index.output[0]])
            element = loop_body_g.make_node("Squeeze", [gathered.output[0]], attr={"axes": [0]})
            loop_body_g.replace_all_inputs(loop_body_g.get_nodes(), input_ta.consumer.id,
                                           element.output[0])

        # --- create Loop node ---
        loop_node = self._create_loop_node(context, loop_props)
        if not loop_node:
            log.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL
        loop_node.set_body_graph_as_attr("body", loop_body_g)

        log.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        # Boundary handler: report the failure to the caller instead of raising.
        log.error("loop rewrite failed, due to exception: %s, details:%s",
                  ex, traceback.format_exc())
        return REWRITER_RESULT.FAIL