def rewrite(self, context):
    """Rewrite the loop described by ``context`` into a Scan node.

    The initial values of the loop's state inputs and scan inputs become the
    Scan node's inputs, state inputs first. With opset 8 each initial value is
    routed through ``_adapt_scan_sequence_input_or_output`` and the output of
    the last node it creates is used. Since opset 9 the tensor is used
    directly.

    The Scan body graph is built from the cell graph's nodes and outputs. It
    gets one graph input per state input, followed by one per scan input.

    The scan length handed to ``_create_scan_node`` is the leading dimension
    of the first scan input that has a non-empty shape whose leading dimension
    is not -1. When no scan input has such a shape, the scan length is -1.

    Args:
        context: rewrite context exposing ``loop_properties`` and
            ``cell_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if the
        Scan node could not be created or any exception was raised; the
        exception and its traceback are logged rather than propagated.
    """
    logger.debug("enter rewrite function")
    try:
        scan_props = context.loop_properties

        def _initial_value(tensor_name):
            # Map one loop initial value to the tensor fed to the Scan node.
            if self.g.opset == 8:
                nodes = self._adapt_scan_sequence_input_or_output(
                    "input", tensor_name, False)
                return nodes[-1].output[0]
            # since opset 9 the value is used as-is
            return tensor_name

        state_inputs_initial_values = [
            _initial_value(state_input)
            for state_input in scan_props.state_inputs_initial_values
        ]

        scan_inputs_initial_values = []
        # -1 means "scan length unknown". Keep the first known leading
        # dimension: scan inputs are expected to share one sequence length.
        # A later input whose leading dim is unknown (-1) must not overwrite
        # a length that was already found.
        scan_length = -1
        for scan_input in scan_props.scan_inputs_initial_values:
            scan_inputs_initial_values.append(_initial_value(scan_input))
            if scan_length == -1:
                scan_shape = self.g.get_shape(scan_input)
                if scan_shape:
                    scan_length = scan_shape[0]

        cell_g_info = context.cell_graph
        scan_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, cell_g_info.nodes, cell_g_info.outputs)

        # Body graph inputs: state inputs first, then scan inputs, matching
        # the order of the initial values passed to the Scan node below.
        for input_tensor_info in scan_props.state_inputs:
            scan_body_g.add_graph_input(input_tensor_info.id,
                                        input_tensor_info.dtype,
                                        input_tensor_info.shape)
        for input_tensor_info in scan_props.scan_inputs:
            scan_body_g.add_graph_input(input_tensor_info.id,
                                        input_tensor_info.dtype,
                                        input_tensor_info.shape)

        scan_node = self._create_scan_node(
            context, scan_props,
            state_inputs_initial_values + scan_inputs_initial_values,
            scan_body_g, scan_length)
        if not scan_node:
            logger.error("failed to create scan node during rewrite")
            return REWRITER_RESULT.FAIL

        self._connect_scan_with_output(context, scan_node)
        return REWRITER_RESULT.OK

    except Exception as ex:
        tb = traceback.format_exc()
        logger.error(
            "custom rnn rewrite failed, due to exception: %s, details:%s",
            ex, tb)
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Rewrite the loop described by ``context`` into a Scan node.

    Every initial value is passed through
    ``_adapt_scan_sequence_input_or_output``, state inputs first and then
    scan inputs. The output of the last node each call returns becomes the
    corresponding Scan input.

    The following nodes are collected locally and appended to the main
    graph's node list (``get_nodes`` / ``set_nodes``) only after the Scan
    node has been created successfully:
      * the nodes returned by those adapter calls;
      * the Scan node itself;
      * the nodes returned by ``_connect_scan_with_output``.

    Args:
        context: rewrite context exposing ``loop_properties`` and
            ``cell_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if the
        Scan node could not be created or any exception was raised; the
        exception and its traceback are logged rather than propagated.
    """
    log.debug("enter rewrite function")
    try:
        props = context.loop_properties
        new_nodes = []

        def adapted(tensor_name):
            # Remember the adapter's nodes for the final graph update and
            # return the output of the last one as the Scan initial value.
            adapter_nodes = self._adapt_scan_sequence_input_or_output(
                "input", tensor_name, False)
            new_nodes.extend(adapter_nodes)
            return adapter_nodes[-1].output[0]

        init_values = [adapted(name) for name in props.state_inputs_initial_values]
        init_values += [adapted(name) for name in props.scan_inputs_initial_values]

        cell_info = context.cell_graph
        body = LoopRewriterBase.construct_graph_from_nodes(
            self.g, cell_info.nodes, cell_info.outputs)
        # state inputs first, then scan inputs
        for group in (props.state_inputs, props.scan_inputs):
            for info in group:
                body.add_graph_input(info.id, info.dtype, info.shape)

        scan_node = self._create_scan_node(context, props, init_values)
        if not scan_node:
            log.error("failed to create scan node during rewrite")
            return REWRITER_RESULT.FAIL

        scan_node.set_body_graph_as_attr("body", body)
        new_nodes.append(scan_node)
        new_nodes.extend(self._connect_scan_with_output(context, scan_node))

        graph_nodes = self.g.get_nodes()
        graph_nodes.extend(new_nodes)
        self.g.set_nodes(graph_nodes)
        return REWRITER_RESULT.OK
    except Exception as ex:
        tb = traceback.format_exc()
        log.error(
            "custom rnn rewrite failed, due to exception: %s, details:%s", ex, tb)
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Rewrite the loop described by ``context`` into a Loop node.

    The initial loop condition is computed by
    ``_create_subgraph_initial_cond`` from the condition graph.

    The Loop body graph is assembled from the nodes of the cell and
    condition graphs. Its outputs are the condition-graph outputs followed by
    the cell-graph outputs, each with its shape replaced by
    ``utils.create_vague_shape_like``. Its inputs are:
      * an INT64 scalar iteration counter;
      * a BOOL scalar condition;
      * one input per loop state variable.
    Tensor-array inputs are read inside the body through an
    Unsqueeze -> Gather -> Squeeze chain on the tensor array's index input.

    Args:
        context: rewrite context exposing ``loop_properties``,
            ``cell_graph`` and ``cond_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if a body
        output has no shape, the Loop node could not be created, or any
        other exception was raised; the exception and its traceback are
        logged rather than propagated.
    """
    logger.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # create a dummy loop to calculate the init condition
        init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

        ## create Loop body graph with existing nodes

        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requires output shape [{}] exists".format(
                    out_tensor_value_info.id))
            out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # create loop body graph inputs
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we created a fake one,
                # the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            # Unsqueeze(axes=[0]) -> Gather -> Squeeze(axes=[0]) reads one element;
            # presumably the index input is a scalar - TODO confirm.
            index_node = loop_body_g.make_node(
                "Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node(
                "Squeeze", [gather_node.output[0]], attr={"axes": [0]})
            loop_body_g.replace_all_inputs(
                input_ta.consumer.id, data_node.output[0])

        ## create Loop node
        branches = {"body": loop_body_g}
        loop_node = self._create_loop_node(
            context, loop_props, init_cond_output, branches=branches)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        tb = traceback.format_exc()
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s", ex, tb)
        return REWRITER_RESULT.FAIL
def rewrite(self, context):
    """Rewrite the loop described by ``context`` into a Loop node.

    First, for every dependent variable of the condition graph, the
    condition-graph nodes are rewired. Inputs that referred to the variable's
    switch-true identity output now refer to its next-iteration input, so the
    condition consumes the values produced by the cell graph.

    The Loop body graph is assembled from the nodes of the cell and
    condition graphs. Its outputs are the condition-graph outputs followed by
    the cell-graph outputs, each with its shape replaced by
    ``utils.create_vague_shape_like``. Its inputs are:
      * an INT64 scalar iteration counter;
      * a BOOL scalar condition;
      * one input per loop state variable.
    Tensor-array inputs are read inside the body through an
    Unsqueeze -> Gather -> Squeeze chain on the tensor array's index input.

    The body graph is attached to the created Loop node as its ``"body"``
    attribute.

    Args:
        context: rewrite context exposing ``loop_properties``,
            ``cell_graph`` and ``cond_graph``.

    Returns:
        ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if a body
        output has no shape, the Loop node could not be created, or any
        other exception was raised; the exception and its traceback are
        logged rather than propagated.
    """
    logger.debug("enter rewrite function")
    try:
        loop_props = context.loop_properties
        cell_g_info = context.cell_graph
        cond_g_info = context.cond_graph

        # TODO(pengwa): we don't check the case where loop body won't be executed at all.

        ## create Loop body graph with existing nodes

        # Replace the condition graph's inputs with the cell graph's outputs,
        # because we want the condition graph to consume cell graph outputs.
        for loop_var in cond_g_info.dependent_vars:
            self.g.replace_all_inputs(
                cond_g_info.nodes,
                loop_var.switch_true_identity_output.id,
                loop_var.next_iteration_input.id)

        body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
        body_outputs = cond_g_info.outputs + cell_g_info.outputs
        for out_tensor_value_info in body_outputs:
            shape = out_tensor_value_info.shape
            # Fail early with a clear message when an output shape is unknown.
            utils.make_sure(
                shape is not None,
                "Conversion of Loop requires output shape [{}] exists".format(
                    out_tensor_value_info.id))
            out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

        loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
            self.g, body_nodes, body_outputs)

        # create loop body graph inputs
        loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
        loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
        for i, tensor_value_info in enumerate(loop_props.state_inputs):
            input_name = tensor_value_info.id
            if input_name is None:
                # if the variable is not used in the body graph, then we created a fake one,
                # the same type and shape as its corresponding output.
                out_tensor_value_info = loop_props.state_outputs[i]
                dtype = out_tensor_value_info.dtype
                shape = out_tensor_value_info.shape
                input_name = utils.make_name("unused_state_input_")
            else:
                dtype = tensor_value_info.dtype
                shape = tensor_value_info.shape

            loop_body_g.add_graph_input(
                input_name, dtype, utils.create_vague_shape_like(shape))

        for input_ta in loop_props.tensor_array_inputs:
            # Loop does not have scan inputs, so we use Gather to get data for each iteration.
            # Unsqueeze(axes=[0]) -> Gather -> Squeeze(axes=[0]) reads one element;
            # presumably the index input is a scalar - TODO confirm.
            index_node = loop_body_g.make_node(
                "Unsqueeze", [input_ta.index_input_id], attr={"axes": [0]})
            gather_node = loop_body_g.make_node(
                "Gather", [input_ta.data_input_id, index_node.output[0]])
            data_node = loop_body_g.make_node(
                "Squeeze", [gather_node.output[0]], attr={"axes": [0]})
            loop_body_g.replace_all_inputs(
                loop_body_g.get_nodes(), input_ta.consumer.id,
                data_node.output[0])

        ## create Loop node
        loop_node = self._create_loop_node(context, loop_props)
        if not loop_node:
            logger.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        loop_node.set_body_graph_as_attr("body", loop_body_g)

        logger.debug("rewrite successfully")
        return REWRITER_RESULT.OK

    except Exception as ex:
        tb = traceback.format_exc()
        logger.error(
            "loop rewrite failed, due to exception: %s, details:%s", ex, tb)
        return REWRITER_RESULT.FAIL