def rewrite(self, context):
    """Build a Scan node (with its body graph) for the matched loop.

    Reads ``context.loop_properties`` and ``context.cell_graph``, builds the
    body graph from the cell nodes, creates the Scan node through
    ``self._create_scan_node`` and wires its outputs through
    ``self._connect_scan_with_output``.

    Returns ``REWRITER_RESULT.OK`` on success. Returns
    ``REWRITER_RESULT.FAIL`` when no scan node could be created or when any
    exception is raised inside the rewrite (the exception is logged, not
    propagated).
    """
    logger.debug("enter rewrite function")
    try:
        props = context.loop_properties

        def initial_value(tensor_id):
            # Only opset 8 routes the initial value through the adapter
            # nodes; from opset 9 on the original id is used unchanged.
            if self.g.opset == 8:
                adapter_nodes = self._adapt_scan_sequence_input_or_output(
                    "input", tensor_id, False)
                return adapter_nodes[-1].output[0]
            return tensor_id

        state_init_ids = [
            initial_value(state_id)
            for state_id in props.state_inputs_initial_values
        ]

        scan_init_ids = []
        scan_length = -1
        for scan_id in props.scan_inputs_initial_values:
            scan_init_ids.append(initial_value(scan_id))
            # The first dimension of the (un-adapted) scan input is recorded
            # as the scan length; the last scan input with a usable shape wins.
            dims = self.g.get_shape(scan_id)
            if dims is None or len(dims) == 0:
                continue
            scan_length = dims[0]

        cell_info = context.cell_graph
        body_graph = LoopRewriterBase.construct_graph_from_nodes(
            self.g, cell_info.nodes, cell_info.outputs)

        # Body graph inputs: state inputs first, then scan inputs.
        for input_group in (props.state_inputs, props.scan_inputs):
            for info in input_group:
                body_graph.add_graph_input(info.id, info.dtype, info.shape)

        scan_node = self._create_scan_node(
            context, props, state_init_ids + scan_init_ids,
            body_graph, scan_length)
        if not scan_node:
            logger.error("failed to create scan node during rewrite")
            return REWRITER_RESULT.FAIL

        self._connect_scan_with_output(context, scan_node)
        return REWRITER_RESULT.OK

    except Exception as ex:
        logger.error(
            "custom rnn rewrite failed, due to exception: %s, details:%s",
            ex, traceback.format_exc())
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Build a Scan node for the matched loop and add the new nodes to the graph.

    Every initial state/scan input is passed through
    ``self._adapt_scan_sequence_input_or_output``. The adapter nodes, the
    Scan node itself and the nodes returned by
    ``self._connect_scan_with_output`` are collected and appended to the
    graph's node list in one step at the end.

    Returns ``REWRITER_RESULT.OK`` on success. Returns
    ``REWRITER_RESULT.FAIL`` when no scan node could be created or when any
    exception is raised inside the rewrite (the exception is logged, not
    propagated).
    """
    log.debug("enter rewrite function")
    try:
        props = context.loop_properties
        pending_nodes = []

        def adapt(tensor_id):
            # Returns the output id of the last adapter node and remembers
            # all adapter nodes so they can be added to the graph later.
            adapter_nodes = self._adapt_scan_sequence_input_or_output(
                "input", tensor_id, False)
            adapted_id = adapter_nodes[-1].output[0]
            pending_nodes.extend(adapter_nodes)
            return adapted_id

        state_init_ids = [
            adapt(state_id) for state_id in props.state_inputs_initial_values
        ]
        scan_init_ids = [
            adapt(scan_id) for scan_id in props.scan_inputs_initial_values
        ]

        cell_info = context.cell_graph
        body_graph = LoopRewriterBase.construct_graph_from_nodes(
            self.g, cell_info.nodes, cell_info.outputs)

        # Body graph inputs: state inputs first, then scan inputs.
        for input_group in (props.state_inputs, props.scan_inputs):
            for info in input_group:
                body_graph.add_graph_input(info.id, info.dtype, info.shape)

        scan_node = self._create_scan_node(
            context, props, state_init_ids + scan_init_ids)
        if not scan_node:
            log.error("failed to create scan node during rewrite")
            return REWRITER_RESULT.FAIL

        scan_node.set_body_graph_as_attr("body", body_graph)
        pending_nodes.append(scan_node)
        pending_nodes.extend(
            self._connect_scan_with_output(context, scan_node))

        # Publish every node created above to the main graph.
        graph_nodes = self.g.get_nodes()
        graph_nodes.extend(pending_nodes)
        self.g.set_nodes(graph_nodes)
        return REWRITER_RESULT.OK

    except Exception as ex:
        log.error(
            "custom rnn rewrite failed, due to exception: %s, details:%s",
            ex, traceback.format_exc())
        return REWRITER_RESULT.FAIL
def _match_cell(self, context, unittype):
    """Match the loop body against the cell patterns known for ``unittype``.

    The candidate ops are the subgraph found by
    ``LoopRewriterBase.find_subgraph`` between the loop's state/scan input
    ids and its state/scan output ids. Each pattern returned by
    ``get_pattern(unittype)`` is tried in order.

    Returns the single match of the first pattern that matches exactly
    once, or ``None`` when no pattern does.
    """
    # The body subgraph depends only on the loop properties and the graph,
    # not on the pattern, so it is searched once (lazily, on the first
    # pattern) instead of once per pattern.
    # NOTE(review): assumes find_subgraph is a pure search and match_ops
    # does not mutate the graph - confirm.
    body_graph_ops = None
    for cell_pattern in get_pattern(unittype):
        matcher = GraphMatcher(cell_pattern, allow_reorder=True)

        if body_graph_ops is None:
            loop_props = context.loop_properties
            inputs = loop_props.state_inputs + loop_props.scan_inputs
            outputs = loop_props.state_outputs + loop_props.scan_outputs
            input_ids = {tensor_info.id for tensor_info in inputs}
            output_ids = {tensor_info.id for tensor_info in outputs}
            body_graph_ops, _, _ = LoopRewriterBase.find_subgraph(
                input_ids, output_ids, self.g, merge_as_end=True)

        match_results = list(matcher.match_ops(body_graph_ops))
        # Zero or several matches are both treated as "no match".
        if len(match_results) == 1:
            return match_results[0]
    return None
def rewrite(self):
    """Attach a body graph to every Scan node that has recorded body-graph info.

    For each node of type "Scan" for which ``BodyGraphDict`` holds an entry,
    the body nodes are located with ``LoopRewriterBase.find_subgraph``, a
    separate ``Graph`` is built from them (Const/ConstV2 nodes become
    initializers of that graph), its inputs and outputs are declared, and
    the result is stored in the Scan node's "body" attribute.

    Afterwards all collected body nodes are removed from the main graph's
    node list, or from its initializers when they are not in the list.

    Returns the main graph's node list after the removals.

    Raises ValueError when a body node is neither in the node list nor an
    initializer of the main graph.
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []

    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue

        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
        # Recorded before the constants are stripped below, so constants
        # are scheduled for removal from the main graph as well.
        nodes_to_remove.extend(onnx_nodes)

        log.debug("start creating body graph for scan node %s ",
                  scan_node.name)
        # Const nodes are not kept as nodes of the body graph; their values
        # are carried over as initializers instead.
        body_graph_initializers = {}
        for const_node in [n for n in onnx_nodes
                           if n.type in ("Const", "ConstV2")]:
            const_output = const_node.output[0]
            body_graph_initializers[const_output] = \
                self.g.initializers[const_output]
            onnx_nodes.remove(const_node)

        ops = [member.op for member in set(onnx_nodes)]
        body_g = Graph(ops, output_shapes=self.g._output_shapes,
                       dtypes=self.g._dtypes)
        body_g._initializers = body_graph_initializers

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        # Inputs at or beyond this position are the scan inputs.
        first_scan_index = len(body_graph_meta.input_ids) - num_scan_inputs
        paired_inputs = zip(body_graph_meta.input_ids,
                            body_graph_meta.initial_input_ids)
        for idx, (input_name, init_input_id) in enumerate(paired_inputs):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is None:
                # Fall back to the shape of the initial input in the main
                # graph; for scan inputs the two leading dims are dropped
                # (the original note calls them [1, time]).
                shape = self.g.get_shape(init_input_id)
                if idx >= first_scan_index:
                    loop_input_shape = list(shape)[2:]
                else:
                    loop_input_shape = list(shape)
            else:
                loop_input_shape = list(shape)

            value_info = helper.make_tensor_value_info(
                input_name, dtype, utils.make_onnx_shape(loop_input_shape))
            body_g.add_model_input(input_name, value_info)

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for output_id in body_graph_meta.output_ids:
            # Each output goes through its own Identity node: the same id
            # may be needed both as a state output and as a scan output,
            # but ONNX does not allow one id to appear more than once among
            # the graph outputs.
            identity_name = utils.make_name("Identity")
            identity_output = utils.port_name(identity_name)
            identity_node = Node(
                helper.make_node("Identity", [output_id], [identity_output],
                                 name=identity_name),
                body_g)
            body_g.set_dtype(identity_output, body_g.get_dtype(output_id))
            body_g.copy_shape(output_id, identity_output)
            new_output_names.append(identity_output)
            temp_nodes.append(identity_node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.output_names = new_output_names
        scan_node.set_attr("body", body_g.make_graph("scan body graph"))

    # Remove the nodes that now live in body graphs from the main graph.
    for stale in set(nodes_to_remove):
        if stale in nodes:
            nodes.remove(stale)
            continue
        if not self.g.is_initializer(stale.output[0]):
            raise ValueError("error when removing nodes")
        del self.g.initializers[stale.output[0]]

    return nodes
def rewrite(self, context):
    """Build a Loop node whose body combines the cell and condition graphs.

    The initial loop condition comes from
    ``self._create_subgraph_initial_cond``. The body graph is built from
    the cell and condition nodes, with the condition outputs listed before
    the cell outputs. Every body output must have a known shape, which is
    replaced by ``utils.create_vague_shape_like`` of itself.

    Returns ``REWRITER_RESULT.OK`` on success. Returns
    ``REWRITER_RESULT.FAIL`` when no loop node could be created or when any
    exception is raised inside the rewrite (the exception is logged, not
    propagated).
    """
    logger.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_info = context.cell_graph
        cond_info = context.cond_graph

        # A separate subgraph computes the condition value before the loop.
        init_cond_output = self._create_subgraph_initial_cond(cond_info)

        # Loop body graph, built from the existing cell + condition nodes.
        body_nodes = set(cell_info.nodes + cond_info.nodes)
        body_outputs = cond_info.outputs + cell_info.outputs
        for out_info in body_outputs:
            known_shape = out_info.shape
            utils.make_sure(
                known_shape is not None,
                "Conversion of Loop requries output shape [{}] exists".format(
                    out_info.id))
            out_info.shape = utils.create_vague_shape_like(known_shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # Body graph inputs: an INT64 "i" and a BOOL "cond" input (both
        # declared with an empty shape tuple), then one per state variable.
        loop_body_g.add_graph_input(
            utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(
            utils.make_name("cond"), TensorProto.BOOL, ())
        for idx, state_in in enumerate(loop_props.state_inputs):
            input_name = state_in.id
            if input_name is None:
                # The variable is not used inside the body: declare a
                # placeholder input with the dtype and shape of the
                # corresponding state output.
                template = loop_props.state_outputs[idx]
                dtype, shape = template.dtype, template.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype, shape = state_in.dtype, state_in.shape
            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop has no scan inputs, so the element for each iteration is
            # fetched explicitly: Unsqueeze(index) -> Gather -> Squeeze.
            index_node = loop_body_g.make_node(
                "Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node(
                "Squeeze", [gather_node.output[0]], attr={"axes": [0]})
            loop_body_g.replace_all_inputs(
                input_ta.consumer.id, data_node.output[0])

        # Create the Loop node with the body graph attached as a branch.
        loop_node = self._create_loop_node(
            context, loop_props, init_cond_output,
            branches={"body": loop_body_g})
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s",
            ex, traceback.format_exc())
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Build a Loop node whose body combines the cell and condition graphs.

    The condition nodes are first rewired to read the cell graph's
    next-iteration values. The body graph is then built from the cell and
    condition nodes, with the condition outputs listed before the cell
    outputs, and attached to the new Loop node as its "body" attribute.

    Returns ``REWRITER_RESULT.OK`` on success. Returns
    ``REWRITER_RESULT.FAIL`` when no loop node could be created or when any
    exception is raised inside the rewrite (the exception is logged, not
    propagated).
    """
    logger.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_info = context.cell_graph
        cond_info = context.cond_graph

        # TODO(pengwa): the case where the loop body is never executed at
        # all is not checked.

        # The condition graph must consume the cell graph's outputs, so each
        # dependent variable's switch-true identity output is replaced by
        # its next-iteration input within the condition nodes.
        for loop_var in cond_info.dependent_vars:
            self.g.replace_all_inputs(
                cond_info.nodes,
                loop_var.switch_true_identity_output.id,
                loop_var.next_iteration_input.id)

        # Loop body graph, built from the existing cell + condition nodes.
        body_nodes = set(cell_info.nodes + cond_info.nodes)
        body_outputs = cond_info.outputs + cell_info.outputs
        for out_info in body_outputs:
            out_info.shape = utils.create_vague_shape_like(out_info.shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # Body graph inputs: an INT64 "i" and a BOOL "cond" input (both
        # declared with an empty shape tuple), then one per state variable.
        loop_body_g.add_graph_input(
            utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(
            utils.make_name("cond"), TensorProto.BOOL, ())
        for idx, state_in in enumerate(loop_props.state_inputs):
            input_name = state_in.id
            if input_name is None:
                # The variable is not used inside the body: declare a
                # placeholder input with the dtype and shape of the
                # corresponding state output.
                template = loop_props.state_outputs[idx]
                dtype, shape = template.dtype, template.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype, shape = state_in.dtype, state_in.shape
            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop has no scan inputs, so the element for each iteration is
            # fetched explicitly: Unsqueeze(index) -> Gather -> Squeeze.
            index_node = loop_body_g.make_node(
                "Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node(
                "Squeeze", [gather_node.output[0]], attr={"axes": [0]})
            loop_body_g.replace_all_inputs(
                loop_body_g.get_nodes(),
                input_ta.consumer.id,
                data_node.output[0])

        # Create the Loop node, then attach the body graph to it.
        loop_node = self._create_loop_node(context, loop_props)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        loop_node.set_body_graph_as_attr("body", loop_body_g)
        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s",
            ex, traceback.format_exc())
        return REWRITER_RESULT.FAIL