Exemplo n.º 1
0
    def _create_loop_node(self, context, loop_props):
        """Create the ONNX ``Loop`` node for the loop described by ``loop_props``.

        For every entry in ``loop_props.state_outputs_exits`` followed by
        ``loop_props.scan_outputs_exits``:

        * if the entry has an ``id``, that id is reused as a Loop output, its
          shape/dtype are recorded, and the node currently producing that id
          is removed from ``self.g``;
        * otherwise a freshly named placeholder output is used, with shape
          ``[-1]`` and dtype ``None``.

        Args:
            context: rewrite context; not read by this method.
            loop_props: object exposing ``state_outputs_exits``,
                ``scan_outputs_exits`` and ``state_inputs_initial_values``.

        Returns:
            The node returned by ``self.g.make_node`` for the ``Loop`` op.
        """
        loop_outputs = []
        loop_output_shapes = []
        loop_output_dtypes = []
        for tensor_value_info in loop_props.state_outputs_exits + loop_props.scan_outputs_exits:
            if tensor_value_info.id:
                loop_outputs.append(tensor_value_info.id)
                loop_output_shapes.append(tensor_value_info.shape)
                loop_output_dtypes.append(tensor_value_info.dtype)
                # The Loop node will produce this output id from now on, so the
                # node that currently produces it has to leave the graph first.
                n = self.g.get_node_by_output(tensor_value_info.id)
                self.g.remove_node(n.name)
            else:
                loop_outputs.append(utils.make_name("unused_loop_output_"))
                loop_output_shapes.append([-1])
                loop_output_dtypes.append(None)

        # trip count and cond are not used, giving them values just because bug
        # (https://github.com/Microsoft/onnxruntime/issues/255) of onnxruntime.
        trip_cnt = self.g.make_const(utils.make_name("trip_count"),
                                     np.array(sys.maxsize, dtype=np.int64))
        # np.bool_ instead of the np.bool alias: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where np.bool raises AttributeError.
        cond = self.g.make_const(utils.make_name("cond"),
                                 np.array(True, dtype=np.bool_))
        loop_inputs = [trip_cnt.output[0], cond.output[0]]
        # ONNX Loop support state inputs only
        loop_inputs = loop_inputs + loop_props.state_inputs_initial_values
        loop_node = self.g.make_node(
            "Loop",
            loop_inputs,
            outputs=loop_outputs,
            op_name_scope="generic_loop",
            shapes=loop_output_shapes,
            dtypes=loop_output_dtypes,
            skip_conversion=False)

        return loop_node
    def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
        """Insert a ``Reshape`` that adds or strips a leading dimension on ``input_id``.

        * ``handle_output`` is ``True``: the first dimension of ``input_id`` is
          dropped (target shape is ``shape[1:]``).
        * otherwise: a leading dimension of size 1 is prepended
          (target shape is ``[1] + shape``).

        When the shape known to ``self.g`` is available and the target shape
        contains at most one ``-1``, the Reshape target is emitted as a
        constant. Otherwise the target is computed at runtime from a ``Shape``
        node.

        Args:
            target_name: prefix used for the generated const names and as the
                ``op_name_scope`` of the Reshape node.
            input_id: output id of the tensor to reshape.
            handle_output: selects the "strip leading dim" behavior when it is
                exactly ``True``.

        Returns:
            List of the nodes created through ``self.g.make_node`` (the Shape
            node, any Cast/Slice/Concat helpers, and the Reshape node last).
            Nodes created through ``self.g.make_const`` are not included.
        """
        nodes_to_add = []
        shape_node = self.g.make_node("Shape", [input_id])
        nodes_to_add.append(shape_node)
        inferred_shape = self.g.get_shape(input_id)
        if inferred_shape is not None:
            # Work on a list so slicing/concatenation below behave uniformly.
            inferred_shape = list(inferred_shape)

        if handle_output is True:
            # handle output:
            # if required dim values don't contain more than one -1,
            # just use a const for Reshape's shape input.
            if inferred_shape is not None and inferred_shape[1:].count(-1) <= 1:
                new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                   np.array(inferred_shape[1:], dtype=np.int64))
            else:
                # otherwise, get the dim dynamically, e.g. remove the fake batch size (e.g.1)
                # from [1, time, real-batch, ...]
                # NOTE(review): the shape is cast to FLOAT before Slice and back
                # to INT64 afterwards - presumably a workaround for Slice not
                # accepting int64 in the targeted runtime/opset; confirm.
                origin_shape_node = self.g.make_node("Cast", [shape_node.output[0]],
                                                     {"to": onnx_pb.TensorProto.FLOAT})
                nodes_to_add.append(origin_shape_node)

                sliced_shape_node = self.g.make_node("Slice", [origin_shape_node.output[0]],
                                                     {"axes": [0], "starts": [1], "ends": [sys.maxsize]})
                nodes_to_add.append(sliced_shape_node)

                new_shape_node = self.g.make_node("Cast", [sliced_shape_node.output[0]],
                                                  {"to": onnx_pb.TensorProto.INT64})
                nodes_to_add.append(new_shape_node)

            # The static shape is only known when the input shape is known;
            # slicing None here used to raise TypeError on the dynamic path.
            new_shape = inferred_shape[1:] if inferred_shape is not None else None
        else:
            # handle input:
            if inferred_shape is not None and inferred_shape.count(-1) <= 1:
                new_shape_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                   np.array([1] + inferred_shape, dtype=np.int64))
            else:
                # add a fake batch size : 1
                fake_batch_size_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
                                                         np.array([1,], dtype=np.int64))
                new_shape_node = self.g.make_node("Concat",
                                                  [fake_batch_size_node.output[0], shape_node.output[0]],
                                                  attr={"axis": 0})
                nodes_to_add.append(new_shape_node)
            # Same guard as above: "[1] + None" used to raise TypeError.
            new_shape = [1] + inferred_shape if inferred_shape is not None else None

        reshape_kwargs = {
            "dtypes": [self.g.get_dtype(input_id)],
            "op_name_scope": target_name,
        }
        if new_shape is not None:
            # Only record an output shape when one is actually known.
            reshape_kwargs["shapes"] = [new_shape]
        reshape_node = self.g.make_node("Reshape", [input_id, new_shape_node.output[0]],
                                        **reshape_kwargs)
        nodes_to_add.append(reshape_node)
        log.debug("create Reshape for scan output %s, with output shape %s",
                  reshape_node.output[0], new_shape)
        return nodes_to_add
    def _create_loop_node(self, context, loop_props):
        """Create the ONNX ``Loop`` node for the loop described by ``loop_props``.

        Loop outputs are taken from ``loop_props.state_outputs_exits`` followed
        by ``loop_props.scan_outputs_exits``. An entry with an ``id`` reuses
        that id; an entry without one gets a freshly generated placeholder
        name.

        Args:
            context: rewrite context; not read by this method.
            loop_props: object exposing ``state_outputs_exits``,
                ``scan_outputs_exits`` and ``state_inputs_initial_values``.

        Returns:
            The node returned by ``self.g.make_node`` for the ``Loop`` op.
        """
        # reuse original output connection id (e.g. Exit_XXX), so we don't need set shape.
        exit_infos = loop_props.state_outputs_exits + loop_props.scan_outputs_exits
        loop_outputs = []
        for tensor_value_info in exit_infos:
            if tensor_value_info.id:
                loop_outputs.append(tensor_value_info.id)
            else:
                loop_outputs.append(utils.make_name("unused_loop_output_"))

        # trip count and cond are not used, giving them values just because bug
        # (https://github.com/Microsoft/onnxruntime/issues/255) of onnxruntime.
        trip_cnt = self.g.make_const(utils.make_name("trip_count"), np.array(sys.maxsize, dtype=np.int64))
        # np.bool_ instead of the np.bool alias: the alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where np.bool raises AttributeError.
        cond = self.g.make_const(utils.make_name("cond"), np.array(True, dtype=np.bool_))
        loop_node = self.g.make_node("Loop", [trip_cnt.output[0]] + [cond.output[0]] +
                                     loop_props.state_inputs_initial_values,  # ONNX Loop support state inputs only
                                     outputs=loop_outputs, op_name_scope="generic_loop",
                                     skip_conversion=False)

        return loop_node
    def rewrite(self):
        """Attach a body graph to every ``Scan`` node that has recorded body-graph info.

        For each ``Scan`` node in ``self.g`` with an entry in
        ``BodyGraphDict``, the nodes of its body subgraph are collected, a
        separate ``Graph`` is built from them (Const/ConstV2 nodes become
        initializers), model inputs and Identity-wrapped outputs are added,
        and the result is stored in the Scan node's ``body`` attribute.

        Returns:
            The node list obtained from ``self.g.get_nodes()`` with all body
            subgraph nodes removed.

        Raises:
            ValueError: if a body subgraph node is neither in the node list
                nor registered as an initializer of ``self.g``.
        """
        log.debug("enter custom rnn late rewriter")
        all_nodes = self.g.get_nodes()
        body_nodes_to_drop = []
        const_types = ("Const", "ConstV2")

        for scan_node in all_nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
            # Recorded before the consts are stripped, so consts are dropped too.
            body_nodes_to_drop.extend(onnx_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            body_graph_initializers = {}
            for const_node in [cand for cand in onnx_nodes if cand.type in const_types]:
                # when set nodes, Const should be removed, they need be replaced as initializers.
                const_output = const_node.output[0]
                body_graph_initializers[const_output] = self.g.initializers[const_output]
                onnx_nodes.remove(const_node)

            ops = [body_node.op for body_node in set(onnx_nodes)]

            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,
                           dtypes=self.g._dtypes)
            body_g._initializers = body_graph_initializers

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            # Inputs at or past this position are the scan inputs.
            first_scan_input = len(body_graph_meta.input_ids) - num_scan_inputs
            input_pairs = zip(body_graph_meta.input_ids,
                              body_graph_meta.initial_input_ids)
            for position, (input_name, init_input_id) in enumerate(input_pairs):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is not None:
                    loop_input_shape = list(shape)
                else:
                    # Fall back to the shape of the initial value in the outer graph.
                    outer_shape = list(self.g.get_shape(init_input_id))
                    if position >= first_scan_input:
                        loop_input_shape = outer_shape[2:]  # delete [1, time,]
                    else:
                        loop_input_shape = outer_shape

                val = helper.make_tensor_value_info(
                    input_name, dtype, utils.make_onnx_shape(loop_input_shape))
                body_g.add_model_input(input_name, val)

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for output_id in body_graph_meta.output_ids:
                # insert identity node, since sometimes we need output same output_id as state_output
                # and scan_out, but ONNX don't allow the same output_id appeared more than once as
                # output node.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                identity_proto = helper.make_node("Identity", [output_id], [identity_output],
                                                  name=identity_name)
                identity_node = Node(identity_proto, body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(output_id))
                body_g.copy_shape(output_id, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(identity_node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            scan_node.set_attr("body", body_g.make_graph("scan body graph"))

        # remove nodes in body graph from g
        for body_node in set(body_nodes_to_drop):
            if body_node in all_nodes:
                all_nodes.remove(body_node)
                continue
            if not self.g.is_initializer(body_node.output[0]):
                raise ValueError("error when removing nodes")
            del self.g.initializers[body_node.output[0]]

        return all_nodes
Exemplo n.º 5
0
    def rewrite(self, context):
        """Rewrite the loop described by ``context`` into a single ONNX ``Loop`` node.

        The cell and condition graphs from ``context`` are merged into one
        body graph, graph inputs are added for the iteration counter, the
        condition and every state variable, tensor-array reads are replaced by
        Unsqueeze/Gather/Squeeze, and the body graph is attached to a newly
        created Loop node.

        Args:
            context: object exposing ``loop_properties``, ``cell_graph`` and
                ``cond_graph``.

        Returns:
            ``REWRITER_RESULT.OK`` on success. ``REWRITER_RESULT.FAIL`` if the
            Loop node could not be created or any exception was raised (the
            exception is logged, not propagated).
        """
        log.debug("enter rewrite function")
        try:
            props = context.loop_properties
            cell_info = context.cell_graph
            cond_info = context.cond_graph

            # todo(pengwa): we don't check the case where loop body won't be executed at all.

            ## create Loop body graph with existing nodes

            # The condition graph must consume the cell graph's outputs, so its
            # inputs are redirected to the next-iteration values.
            for dep_var in cond_info.dependent_vars:
                self.g.replace_all_inputs(cond_info.nodes,
                                          dep_var.switch_true_identity_output.id,
                                          dep_var.next_iteration_input.id)

            body_nodes = set(cell_info.nodes + cond_info.nodes)
            body_outputs = cond_info.outputs + cell_info.outputs
            for body_out in body_outputs:
                body_out.shape = utils.create_vague_shape_like(body_out.shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
                self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
            for idx, state_in in enumerate(props.state_inputs):
                input_name = state_in.id
                if input_name is not None:
                    dtype, shape = state_in.dtype, state_in.shape
                else:
                    # if the variable is not used in the body graph, then we created a fake one,
                    # the same type and shape as its corresponding output.
                    matching_out = props.state_outputs[idx]
                    dtype, shape = matching_out.dtype, matching_out.shape
                    input_name = utils.make_name("unused_state_input_")

                vague_shape = utils.create_vague_shape_like(shape)
                loop_body_g.add_graph_input(input_name, dtype, vague_shape)

            for ta_input in props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                unsqueezed_index = loop_body_g.make_node(
                    "Unsqueeze", [ta_input.index_input_id], attr={"axes": [0]})
                gathered = loop_body_g.make_node(
                    "Gather", [ta_input.data_input_id, unsqueezed_index.output[0]])
                squeezed = loop_body_g.make_node(
                    "Squeeze", [gathered.output[0]], attr={"axes": [0]})
                loop_body_g.replace_all_inputs(loop_body_g.get_nodes(),
                                               ta_input.consumer.id,
                                               squeezed.output[0])

            ## create Loop node
            loop_node = self._create_loop_node(context, props)
            if loop_node:
                loop_node.set_body_graph_as_attr("body", loop_body_g)
                log.debug("rewrite successfully")
                return REWRITER_RESULT.OK

            log.error("failed to create loop node during rewrite")
            return REWRITER_RESULT.FAIL

        except Exception as ex:
            log.error("loop rewrite failed, due to exception: %s, details:%s",
                      ex, traceback.format_exc())
            return REWRITER_RESULT.FAIL