Example #1
    def _connect_lstm_ych_to_graph(self, context, i):
        # in tf, the concat of y_c and y_h has output shape: [batch, hidden * 2]
        # in onnx, the y_c/y_h output shape is: [num_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct_ht" + str(i)].exit_output
        lstm_node = context.rnn_node[i]
        yc_shape = self.g.get_shape(lstm_node.output[2])
        concat_output_shape = [yc_shape[0], yc_shape[1], yc_shape[2] * 2]
        concat = self.g.make_node(
            "Concat", [lstm_node.output[2], lstm_node.output[1]],
            attr={"axis": 2},
            shapes=[concat_output_shape],
            dtypes=[self.g.get_dtype(lstm_node.output[2])])

        squeeze_output_shape = [concat_output_shape[1], concat_output_shape[2]]
        squeeze_node = gb.make_squeeze(
            {
                "data": concat.output[0],
                "axes": [0]
            },
            shapes=[squeeze_output_shape],
            dtypes=[self.g.get_dtype(concat.output[0])],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
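The pattern above is the standard shape fix-up: the ONNX LSTM emits y_h and y_c each as [num_directions, batch, hidden], so Concat on axis 2 joins the two states along the hidden axis and Squeeze on axis 0 drops the leading num_directions axis. A minimal NumPy sketch of the same shape arithmetic (the sizes are illustrative, not taken from the code):

import numpy as np

num_directions, batch, hidden = 1, 4, 8
y_h = np.zeros((num_directions, batch, hidden), dtype=np.float32)
y_c = np.zeros((num_directions, batch, hidden), dtype=np.float32)

# Concat on axis 2 mirrors the ONNX Concat node built above.
concat = np.concatenate([y_c, y_h], axis=2)  # (1, 4, 16)
# Squeeze on axis 0 mirrors gb.make_squeeze(..., axes=[0]).
ch = np.squeeze(concat, axis=0)              # (4, 16) == [batch, hidden * 2]
assert ch.shape == (batch, hidden * 2)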
Example #2

    def connect_unit_rnn_output_to_graph(self, context):
        outputs = context.loop_properties.scan_outputs_exits
        if not outputs:
            logger.debug("no one consume output")
            return

        gather_output_id = outputs[0].id
        logger.debug("found output for rnn: %s", gather_output_id)

        # in tf batch major mode, output shape is: [batch, time, hidden]
        # in time major mode, output shape is: [time, batch, hidden]
        # in onnx, output shape is: [time, num_directions, batch, hidden]

        rnn_node = context.rnn_node
        output_id = rnn_node.output[0]
        rnn_output_shape = self.g.get_shape(output_id)
        squeeze_output_shape = [
            rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
        ]
        gb = GraphBuilder(self.g)
        squeeze_node = gb.make_squeeze(
            {
                "data": output_id,
                "axes": [1]
            },
            shapes=[squeeze_output_shape],
            dtypes=[self.g.get_dtype(output_id)],
            return_node=True)
        self.g.replace_all_inputs(
            gather_output_id, squeeze_node.output[0])  # ops=self.g.get_nodes()
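A side note on why this goes through GraphBuilder.make_squeeze instead of a plain make_node: in opset 13 ONNX moved the axes of Squeeze/Unsqueeze from an attribute to a second input, and GraphBuilder emits whichever form the target opset expects. A minimal sketch of the two wire formats, assuming only that the onnx package is available:

from onnx import TensorProto, helper

# Through opset 12, Squeeze carries its axes as an attribute.
squeeze_v12 = helper.make_node("Squeeze", ["y"], ["y_squeezed"], axes=[1])

# From opset 13 on, axes is a second input, typically fed by an initializer.
axes_init = helper.make_tensor("axes", TensorProto.INT64, dims=[1], vals=[1])
squeeze_v13 = helper.make_node("Squeeze", ["y", "axes"], ["y_squeezed"])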
Example #3
    def create_rnn_node(self, context):
        gb = GraphBuilder(self.g)
        rnn_nodes = list()
        outputs = context.loop_properties.scan_outputs_exits
        logger.debug("number of rnn node outputs: %s", len(outputs))

        for i in range(self.num_lstm_layers):
            logger.debug("creating rnn node for layer: %s", i)
            rnn_nodes.append(self.create_single_rnn_node(context, i))
            output_id = rnn_nodes[i].output[0]
            rnn_output_shape = self.g.get_shape(output_id)
            squeeze_output_shape = [
                rnn_output_shape[0], rnn_output_shape[2], rnn_output_shape[3]
            ]
            squeeze_node = gb.make_squeeze(
                {
                    "data": output_id,
                    "axes": [1]
                },
                shapes=[squeeze_output_shape],
                dtypes=[self.g.get_dtype(output_id)],
                return_node=True)
            if i + 1 < self.num_lstm_layers:
                logger.debug("setting input for layer: %s", i + 1)
                context.onnx_input_ids[i + 1]["X"] = squeeze_node.output[0]
        return rnn_nodes
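Each LSTM layer's Y output comes back as [time, num_directions, batch, hidden], and the Squeeze on axis 1 restores the [time, batch, hidden] layout the next layer expects as X. A rough NumPy trace of that per-layer hand-off, with a stand-in for the LSTM itself and illustrative sizes:

import numpy as np

time, batch, hidden = 5, 4, 8
x = np.zeros((time, batch, hidden), dtype=np.float32)
for layer in range(2):
    # Stand-in for the ONNX LSTM: Y carries an extra num_directions axis.
    y = x[:, np.newaxis, :, :]       # (time, 1, batch, hidden)
    # make_squeeze(..., axes=[1]) removes it before the next layer reads X.
    x = np.squeeze(y, axis=1)        # (time, batch, hidden)
assert x.shape == (time, batch, hidden)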
Example #4
    def version_13(cls, ctx, node, **kwargs):
        ctx.ta_reads.append(node.input[0])
        node.type = "Gather"
        ctx.replace_inputs(node, [node.input[0], node.input[1]])

        g = GraphBuilder(ctx)

        usq_node = g.make_unsqueeze({"axes": [0], "name": node.child_name(), "data": node.input[1]}, return_node=True)
        ctx.insert_node_on_output(usq_node)

        sq_node = g.make_squeeze({"axes": [0], "name": node.child_name(), "data": node.output[0]}, return_node=True)
        ctx.insert_node_on_output(sq_node)
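The Unsqueeze/Gather/Squeeze sandwich turns the scalar TensorArray index into a shape-[1] index, gathers a single slice, and then drops the length-1 axis again, keeping every intermediate rank explicit. The NumPy equivalent of the three nodes (data and index values are illustrative):

import numpy as np

data = np.arange(12, dtype=np.float32).reshape(3, 4)
index = np.array(1, dtype=np.int64)      # scalar TensorArray index

idx = np.expand_dims(index, 0)           # Unsqueeze(axes=[0]) -> shape (1,)
picked = np.take(data, idx, axis=0)      # Gather              -> shape (1, 4)
element = np.squeeze(picked, axis=0)     # Squeeze(axes=[0])   -> shape (4,)
assert np.array_equal(element, data[1])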
Example #5
    def _connect_lstm_yc_to_graph(self, context, i):
        # in tf, y_c output shape is: [batch, hidden]
        # in onnx, output shape is: [num_directions, batch, hidden]
        gb = GraphBuilder(self.g)
        exit_output = context.state_variables["ct" + str(i)].exit_output
        output_id = context.rnn_node[i].output[2]
        lstm_yc_shape = self.g.get_shape(output_id)
        squeeze_node = gb.make_squeeze(
            {
                "data": output_id,
                "axes": [0]
            },
            shapes=[[lstm_yc_shape[1], lstm_yc_shape[2]]],
            dtypes=[self.g.get_dtype(output_id)],
            return_node=True)

        self.g.replace_all_inputs(
            exit_output.id, squeeze_node.output[0])  # ops=self.g.get_nodes()
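For a unidirectional LSTM the leading num_directions axis of y_c is 1, so the Squeeze is shape-safe and recovers tf's [batch, hidden] cell state. In NumPy terms, with illustrative sizes:

import numpy as np

y_c = np.zeros((1, 4, 8), dtype=np.float32)  # [num_directions, batch, hidden]
ct = np.squeeze(y_c, axis=0)                 # [batch, hidden], tf's c state layout
assert ct.shape == (4, 8)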
Example #6
    def version_10(cls, ctx, node, **kwargs):
        x = node.input[0]
        x_shape = ctx.get_shape(x)
        h = node.input[1]
        h_shape = ctx.get_shape(h)
        p = node.input[3]
        utils.make_sure(node.attr["rnn_mode"].s == b"gru",
                        "rnn mode other than gru are not supported yet")
        utils.make_sure(node.attr["dropout"].f == 0,
                        "dropout not supported yet")
        utils.make_sure(node.attr["input_mode"].s == b"linear_input",
                        "input mode must be linear input")
        num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
        num_layers = int(h_shape[0] / num_dirs)
        num_units = hidden_size = h_shape[2]
        input_size = x_shape[2]
        w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
        w_shape_const = ctx.make_const(utils.make_name("w_shape"),
                                       np.array(w_shape, dtype=np.int64))
        r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
        r_shape_const = ctx.make_const(utils.make_name("r_shape"),
                                       np.array(r_shape, dtype=np.int64))
        b_shape = [num_layers * num_dirs, 6 * hidden_size]
        b_shape_const = ctx.make_const(utils.make_name("b_shape"),
                                       np.array(b_shape, dtype=np.int64))
        zero_const = ctx.make_const(utils.make_name("zero"),
                                    np.array([0], dtype=np.int64))
        w_end = np.prod(w_shape)
        w_end_const = ctx.make_const(utils.make_name("w_end"),
                                     np.array([w_end], dtype=np.int64))
        r_end = w_end + np.prod(r_shape)
        r_end_const = ctx.make_const(utils.make_name("r_end"),
                                     np.array([r_end], dtype=np.int64))
        b_end = r_end + np.prod(b_shape)
        b_end_const = ctx.make_const(utils.make_name("b_end"),
                                     np.array([b_end], dtype=np.int64))

        def name(nm):
            return node.name + "_" + nm

        ws = [name('W_' + str(i)) for i in range(num_layers * num_dirs)]
        rs = [name('R_' + str(i)) for i in range(num_layers * num_dirs)]
        bs = [name('B_' + str(i)) for i in range(num_layers * num_dirs)]
        hs = [name('H_' + str(i)) for i in range(num_layers * num_dirs)]
        yhs = [name('YH_' + str(i)) for i in range(num_layers * num_dirs)]
        w_flattened = ctx.make_node(
            'Slice', [p, zero_const.output[0], w_end_const.output[0]])
        r_flattened = ctx.make_node(
            'Slice', [p, w_end_const.output[0], r_end_const.output[0]])
        b_flattened = ctx.make_node(
            'Slice', [p, r_end_const.output[0], b_end_const.output[0]])
        w = utils.make_name('W')
        r = utils.make_name('R')
        b = utils.make_name('B')
        ctx.make_node('Reshape',
                      [w_flattened.output[0], w_shape_const.output[0]],
                      outputs=[w])
        ctx.make_node('Reshape',
                      [r_flattened.output[0], r_shape_const.output[0]],
                      outputs=[r])
        ctx.make_node('Reshape',
                      [b_flattened.output[0], b_shape_const.output[0]],
                      outputs=[b])
        ctx.make_node('Split', [w], outputs=ws)
        ctx.make_node('Split', [r], outputs=rs)
        ctx.make_node('Split', [b], outputs=bs)
        ctx.make_node('Split', [h], outputs=hs)

        builder = GraphBuilder(ctx)

        xnf = xnb = x
        for i in range(num_layers):
            suffix = '_' + str(i * num_dirs)
            ctx.make_node(
                'GRU', [
                    xnf,
                    name('W' + suffix),
                    name('R' + suffix),
                    name('B' + suffix), '',
                    name('H' + suffix)
                ],
                outputs=[name('Y' + suffix),
                         name('YH' + suffix)],
                attr={
                    'direction': 'forward',
                    'hidden_size': num_units
                })
            xnf = name(x + suffix)
            builder.make_squeeze({
                'data': name('Y' + suffix),
                'outputs': [xnf],
                'axes': [1]
            })
            if num_dirs == 2:
                suffix = '_' + str(i * 2 + 1)
                ctx.make_node(
                    'GRU', [
                        xnb,
                        name('W' + suffix),
                        name('R' + suffix),
                        name('B' + suffix), '',
                        name('H' + suffix)
                    ],
                    outputs=[name('Y' + suffix),
                             name('YH' + suffix)],
                    attr={
                        'direction': 'reverse',
                        'hidden_size': num_units
                    })
                xnb = name(x + suffix)
                builder.make_squeeze({
                    'data': name('Y' + suffix),
                    'outputs': [xnb],
                    'axes': [1]
                })
        ctx.remove_node(node.name)
        if num_dirs == 2:
            ctx.make_node('Concat', [xnf, xnb],
                          outputs=[node.output[0]],
                          attr={'axis': -1})
        else:
            ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
        ctx.make_node('Concat',
                      yhs,
                      outputs=[node.output[1]],
                      attr={'axis': 0})
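The slicing logic above relies on CudnnRNN packing all GRU weights into one flat parameter tensor p, laid out as W, then R, then B; the handler cuts it at w_end/r_end/b_end and reshapes each piece. A NumPy sketch of that layout arithmetic, with illustrative sizes:

import numpy as np

num_layers, num_dirs, hidden_size, input_size = 1, 1, 8, 8
w_shape = [num_layers * num_dirs, 3 * hidden_size, input_size]
r_shape = [num_layers * num_dirs, 3 * hidden_size, hidden_size]
b_shape = [num_layers * num_dirs, 6 * hidden_size]

w_end = np.prod(w_shape)
r_end = w_end + np.prod(r_shape)
b_end = r_end + np.prod(b_shape)

p = np.zeros(b_end, dtype=np.float32)  # flat CudnnRNN parameter buffer
w = p[:w_end].reshape(w_shape)         # Slice + Reshape, as the nodes above do
r = p[w_end:r_end].reshape(r_shape)
b = p[r_end:b_end].reshape(b_shape)
assert w.shape == tuple(w_shape) and b.shape == tuple(b_shape)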
Example #7

    def rewrite(self, context):
        logger.debug("enter rewrite function")
        loop_node = None
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # create a dummy loop to calculate the init condition
            init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

            ## create Loop body graph with existing nodes

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                shape = out_tensor_value_info.shape
                utils.make_sure(
                    shape is not None,
                    "Conversion of Loop requries output shape [{}] exists".format(out_tensor_value_info.id)
                )
                out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, then we created a fake one,
                    # the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                gb = GraphBuilder(loop_body_g)
                index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
                gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
                loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])  # ops=loop_body_g.get_nodes()

            ## create Loop node
            branches = {"body": loop_body_g}
            loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
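Because ONNX Loop, unlike Scan, has no scan inputs, the body graph built above reads one element of each stacked tensor-array input per iteration through the same Unsqueeze/Gather/Squeeze chain. Its per-iteration effect, sketched in NumPy with illustrative data:

import numpy as np

data = np.arange(6, dtype=np.float32).reshape(3, 2)   # stacked TensorArray input
for i in range(data.shape[0]):                        # the Loop trip counter
    idx = np.expand_dims(np.int64(i), 0)              # Unsqueeze the counter
    elem = np.squeeze(np.take(data, idx, axis=0), 0)  # Gather + Squeeze -> data[i]
    assert np.array_equal(elem, data[i])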