def rewrite(self, context):
        """Rewrite the loop described by ``context`` into a scan node.

        Gathers the initial values of the state inputs and scan inputs from
        ``context.loop_properties``, builds the scan body graph out of
        ``context.cell_graph``, creates the scan node and connects its
        outputs.

        Args:
            context: rewrite context; its ``loop_properties`` and
                ``cell_graph`` attributes are read here.

        Returns:
            ``REWRITER_RESULT.OK`` on success, ``REWRITER_RESULT.FAIL`` when
            no scan node could be created or when any exception is raised
            (the exception is logged, not propagated).
        """
        logger.debug("enter rewrite function")
        try:
            props = context.loop_properties

            def _initial_value(tensor_id):
                # Only opset 8 routes the value through the adapter nodes;
                # since opset 9 the original tensor is used unchanged.
                if self.g.opset == 8:
                    adapter_nodes = self._adapt_scan_sequence_input_or_output(
                        "input", tensor_id, False)
                    return adapter_nodes[-1].output[0]
                return tensor_id

            state_init = [
                _initial_value(tensor_id)
                for tensor_id in props.state_inputs_initial_values
            ]

            scan_init = []
            # Stays -1 when no scan input exposes a non-empty shape.
            scan_length = -1
            for tensor_id in props.scan_inputs_initial_values:
                scan_init.append(_initial_value(tensor_id))
                shape = self.g.get_shape(tensor_id)
                if shape is not None and len(shape) > 0:
                    # Later scan inputs overwrite the value of earlier ones.
                    scan_length = shape[0]

            cell = context.cell_graph
            body_graph = LoopRewriterBase.construct_graph_from_nodes(
                self.g, cell.nodes, cell.outputs)
            # State inputs are registered first, scan inputs afterwards.
            for group in (props.state_inputs, props.scan_inputs):
                for info in group:
                    body_graph.add_graph_input(info.id, info.dtype, info.shape)

            scan_node = self._create_scan_node(
                context, props, state_init + scan_init, body_graph,
                scan_length)
            if scan_node:
                self._connect_scan_with_output(context, scan_node)
                return REWRITER_RESULT.OK

            logger.error("failed to create scan node during rewrite")
            return REWRITER_RESULT.FAIL

        except Exception as ex:
            # Best-effort rewrite: report the failure instead of raising.
            details = traceback.format_exc()
            logger.error(
                "custom rnn rewrite failed, due to exception: %s, details:%s",
                ex, details)
            return REWRITER_RESULT.FAIL
Example #2
0
    def rewrite(self, context):
        """Rewrite the loop described by ``context`` into a scan node.

        Every state/scan initial value is passed through
        ``_adapt_scan_sequence_input_or_output``; the nodes produced along the
        way, the scan node itself and the nodes returned by
        ``_connect_scan_with_output`` are appended to the graph's node list.

        Args:
            context: rewrite context; its ``loop_properties`` and
                ``cell_graph`` attributes are read here.

        Returns:
            ``REWRITER_RESULT.OK`` on success, ``REWRITER_RESULT.FAIL`` when
            no scan node could be created or when any exception is raised
            (the exception is logged, not propagated).
        """
        log.debug("enter rewrite function")
        try:
            props = context.loop_properties
            new_nodes = []

            def _adapt(tensor_id):
                # Remember the adapter nodes and hand back the output of the
                # last one, which becomes the scan node's input.
                adapter_nodes = self._adapt_scan_sequence_input_or_output(
                    "input", tensor_id, False)
                adapted_output = adapter_nodes[-1].output[0]
                new_nodes.extend(adapter_nodes)
                return adapted_output

            state_init = [
                _adapt(tensor_id)
                for tensor_id in props.state_inputs_initial_values
            ]
            scan_init = [
                _adapt(tensor_id)
                for tensor_id in props.scan_inputs_initial_values
            ]

            cell = context.cell_graph
            body_graph = LoopRewriterBase.construct_graph_from_nodes(
                self.g, cell.nodes, cell.outputs)
            # State inputs are registered first, scan inputs afterwards.
            for group in (props.state_inputs, props.scan_inputs):
                for info in group:
                    body_graph.add_graph_input(info.id, info.dtype, info.shape)

            scan_node = self._create_scan_node(
                context, props, state_init + scan_init)
            if not scan_node:
                log.error("failed to create scan node during rewrite")
                return REWRITER_RESULT.FAIL

            scan_node.set_body_graph_as_attr("body", body_graph)
            new_nodes.append(scan_node)
            new_nodes.extend(self._connect_scan_with_output(context, scan_node))

            # Extend the list returned by get_nodes() in place, then hand it
            # back to the graph.
            graph_nodes = self.g.get_nodes()
            graph_nodes.extend(new_nodes)
            self.g.set_nodes(graph_nodes)

            return REWRITER_RESULT.OK

        except Exception as ex:
            # Best-effort rewrite: report the failure instead of raising.
            details = traceback.format_exc()
            log.error(
                "custom rnn rewrite failed, due to exception: %s, details:%s",
                ex, details)
            return REWRITER_RESULT.FAIL
Example #3
0
    def rewrite(self, context):
        """Rewrite the loop described by ``context`` into a Loop node.

        Builds the loop body graph from the cell and condition graphs, adds
        the body graph inputs (iteration counter, condition and one input per
        state variable), replaces tensor-array reads with
        Unsqueeze/Gather/Squeeze nodes and finally creates the Loop node with
        the body graph as its ``body`` branch.

        Args:
            context: rewrite context; its ``loop_properties``, ``cell_graph``
                and ``cond_graph`` attributes are read here.

        Returns:
            ``REWRITER_RESULT.OK`` on success, ``REWRITER_RESULT.FAIL`` when
            no loop node could be created or when any exception is raised
            (the exception is logged, not propagated).
        """
        logger.debug("enter rewrite function")
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # create a dummy loop to calculate the init condition
            init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

            ## create Loop body graph with existing nodes

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            # Condition outputs come first, followed by the cell outputs.
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                shape = out_tensor_value_info.shape
                # Every body output needs a known shape before it is replaced
                # by the result of utils.create_vague_shape_like.
                utils.make_sure(
                    shape is not None,
                    "Conversion of Loop requires output shape [{}] exists".
                    format(out_tensor_value_info.id))
                out_tensor_value_info.shape = utils.create_vague_shape_like(
                    shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
                self.g, body_nodes, body_outputs)

            # create loop body graph inputs: iteration counter and condition
            # first, then one input per state variable
            loop_body_g.add_graph_input(utils.make_name("i"),
                                        TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"),
                                        TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, then we created a fake one,
                    # the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(
                    input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                index_node = loop_body_g.make_node("Unsqueeze",
                                                   [input_ta.index_input_id],
                                                   attr={"axes": [0]})
                gather_node = loop_body_g.make_node(
                    "Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = loop_body_g.make_node("Squeeze",
                                                  [gather_node.output[0]],
                                                  attr={"axes": [0]})
                # Redirect consumers of the tensor-array read to the gathered
                # element.
                loop_body_g.replace_all_inputs(input_ta.consumer.id,
                                               data_node.output[0])

            ## create Loop node
            branches = {"body": loop_body_g}
            loop_node = self._create_loop_node(context,
                                               loop_props,
                                               init_cond_output,
                                               branches=branches)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            # Best-effort rewrite: report the failure instead of raising.
            tb = traceback.format_exc()
            logger.error(
                "loop rewrite failed, due to exception: %s, details:%s", ex,
                tb)
            return REWRITER_RESULT.FAIL
Example #4
0
    def rewrite(self, context):
        """Rewrite the loop described by ``context`` into a Loop node.

        Rewires the condition graph to consume the cell graph's outputs,
        builds the loop body graph from both graphs, adds the body graph
        inputs (iteration counter, condition and one input per state
        variable), replaces tensor-array reads with Unsqueeze/Gather/Squeeze
        nodes, creates the Loop node and attaches the body graph to it.

        Args:
            context: rewrite context; its ``loop_properties``, ``cell_graph``
                and ``cond_graph`` attributes are read here.

        Returns:
            ``REWRITER_RESULT.OK`` on success, ``REWRITER_RESULT.FAIL`` when
            no loop node could be created or when any exception is raised
            (the exception is logged, not propagated).
        """
        logger.debug("enter rewrite function")
        try:
            props = context.loop_properties
            cell = context.cell_graph
            cond = context.cond_graph

            # todo(pengwa): we don't check the case where loop body won't be executed at all.

            # Make the condition graph consume the cell graph's outputs: each
            # dependent variable's switch-true identity output is replaced by
            # its next-iteration input.
            for loop_var in cond.dependent_vars:
                self.g.replace_all_inputs(
                    cond.nodes,
                    loop_var.switch_true_identity_output.id,
                    loop_var.next_iteration_input.id)

            # Body graph is assembled from the existing nodes; condition
            # outputs precede the cell outputs.
            body_nodes = set(cell.nodes + cond.nodes)
            body_outputs = cond.outputs + cell.outputs
            for out_info in body_outputs:
                out_info.shape = utils.create_vague_shape_like(out_info.shape)

            body_graph = LoopRewriterBase.construct_graph_from_nodes(
                self.g, body_nodes, body_outputs)

            # Body graph inputs: iteration counter, condition, then states.
            body_graph.add_graph_input(
                utils.make_name("i"), TensorProto.INT64, ())
            body_graph.add_graph_input(
                utils.make_name("cond"), TensorProto.BOOL, ())
            for idx, in_info in enumerate(props.state_inputs):
                # A state variable that is unused in the body graph has no id;
                # a fake input is created for it with the type and shape of
                # the matching state output.
                unused = in_info.id is None
                source = props.state_outputs[idx] if unused else in_info
                dtype, shape = source.dtype, source.shape
                if unused:
                    name = utils.make_name("unused_state_input_")
                else:
                    name = in_info.id
                body_graph.add_graph_input(
                    name, dtype, utils.create_vague_shape_like(shape))

            for ta_input in props.tensor_array_inputs:
                # Loop does not have scan inputs, so Gather fetches the data
                # for each iteration.
                unsqueezed = body_graph.make_node(
                    "Unsqueeze", [ta_input.index_input_id],
                    attr={"axes": [0]})
                gathered = body_graph.make_node(
                    "Gather", [ta_input.data_input_id, unsqueezed.output[0]])
                squeezed = body_graph.make_node(
                    "Squeeze", [gathered.output[0]], attr={"axes": [0]})
                body_graph.replace_all_inputs(
                    body_graph.get_nodes(),
                    ta_input.consumer.id,
                    squeezed.output[0])

            loop_node = self._create_loop_node(context, props)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            loop_node.set_body_graph_as_attr("body", body_graph)
            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            # Best-effort rewrite: report the failure instead of raising.
            details = traceback.format_exc()
            logger.error(
                "loop rewrite failed, due to exception: %s, details:%s", ex,
                details)
            return REWRITER_RESULT.FAIL