Code example #1: split a fused c/h initial-state tensor into its two halves with Slice and add a leading axis to each with Unsqueeze.
    def _process_non_tuple_ch_init_nodes(self, context, i):
        gb = GraphBuilder(self.g)
        input_id = context.state_variables["ct_ht" + str(i)].enter_input_id
        hidden_size = context.hidden_size[i]

        # first half of the fused state: [:, 0:hidden_size]
        attr = {"axes": [1], "starts": [0], "ends": [hidden_size]}
        slice_node1 = gb.make_slice({"data": input_id, **attr})
        unsqueeze_node_1 = gb.make_unsqueeze(
            {"data": slice_node1, "axes": [0]}, return_node=True)

        # second half of the fused state: [:, hidden_size:2*hidden_size]
        attr = {"axes": [1], "starts": [hidden_size], "ends": [hidden_size * 2]}
        slice_node2 = gb.make_slice({"data": input_id, **attr})
        unsqueeze_node_2 = gb.make_unsqueeze(
            {"data": slice_node2, "axes": [0]}, return_node=True)

        return unsqueeze_node_1.output[0], unsqueeze_node_2.output[0]
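
For intuition, the following is a small NumPy sketch of the tensor manipulation the two Slice/Unsqueeze pairs above produce; the batch size and hidden_size values are illustrative assumptions, not taken from the original code.

import numpy as np

# illustrative assumptions: batch = 2, hidden_size = 3
hidden_size = 3
ct_ht = np.arange(2 * 2 * hidden_size, dtype=np.float32).reshape(2, 2 * hidden_size)

# Slice(axes=[1], starts=[0], ends=[hidden_size]) followed by Unsqueeze(axes=[0])
first_half = ct_ht[:, :hidden_size][np.newaxis, ...]
# Slice(axes=[1], starts=[hidden_size], ends=[2*hidden_size]) followed by Unsqueeze(axes=[0])
second_half = ct_ht[:, hidden_size:2 * hidden_size][np.newaxis, ...]

print(first_half.shape, second_half.shape)  # (1, 2, 3) (1, 2, 3)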
Code example #2: opset-13 handler that rewrites a TensorArray read as Gather, unsqueezing the index and squeezing the gathered result.
    def version_13(cls, ctx, node, **kwargs):
        # record the TensorArray read and rewrite the node in place as a Gather
        ctx.ta_reads.append(node.input[0])
        node.type = "Gather"
        ctx.replace_inputs(node, [node.input[0], node.input[1]])

        g = GraphBuilder(ctx)

        # unsqueeze the index to rank 1 and splice it in ahead of the Gather
        usq_node = g.make_unsqueeze(
            {"axes": [0], "name": node.child_name(), "data": node.input[1]}, return_node=True)
        ctx.insert_node_on_output(usq_node)

        # squeeze the extra leading axis back off the gathered result
        sq_node = g.make_squeeze(
            {"axes": [0], "name": node.child_name(), "data": node.output[0]}, return_node=True)
        ctx.insert_node_on_output(sq_node)
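
For reference, here is a minimal standalone ONNX model (built with the plain onnx.helper API, not tf2onnx) showing the Unsqueeze -> Gather -> Squeeze pattern this handler wires into the graph; tensor names and shapes are assumptions chosen for the sketch.

import onnx
from onnx import TensorProto, helper

# opset 13: Unsqueeze/Squeeze take "axes" as an input tensor, not an attribute
axes0 = helper.make_tensor("axes0", TensorProto.INT64, [1], [0])

nodes = [
    helper.make_node("Unsqueeze", ["index", "axes0"], ["index_1d"]),
    helper.make_node("Gather", ["data", "index_1d"], ["gathered"], axis=0),
    helper.make_node("Squeeze", ["gathered", "axes0"], ["element"]),
]

graph = helper.make_graph(
    nodes,
    "gather_read",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [None, 4]),
        helper.make_tensor_value_info("index", TensorProto.INT64, []),
    ],
    outputs=[helper.make_tensor_value_info("element", TensorProto.FLOAT, [4])],
    initializer=[axes0],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)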
Code example #3: add a leading axis to a c or h initial-state input, using numpy expand_dims for constants and an Unsqueeze node otherwise.
    def _process_c_or_h_init_nodes(self, initializer_input_id, context):
        node = self.g.get_node_by_output(initializer_input_id)
        if node.is_const():
            # constant initializer: add the leading axis directly on the value
            val = node.get_tensor_value(as_list=False)
            initial_name = utils.make_name("Const")
            new_val = np.expand_dims(val, axis=0)
            const_node = self.g.make_const(initial_name, new_val)
            return const_node.output[0]

        # non-constant initializer: insert an Unsqueeze node and rewire every
        # other consumer of the original input to read the unsqueezed output
        gb = GraphBuilder(self.g)
        unsqueeze_node = gb.make_unsqueeze(
            {"data": initializer_input_id, "axes": [0]}, return_node=True)
        to_replace = [n for n in self.g.get_nodes() if n != unsqueeze_node]
        self.g.replace_all_inputs(initializer_input_id,
                                  unsqueeze_node.output[0],
                                  ops=to_replace)
        return unsqueeze_node.output[0]
Code example #4: rewrite a loop into an ONNX Loop node, constructing the body graph from the existing nodes and emulating scan inputs with Gather.
    def rewrite(self, context):
        logger.debug("enter rewrite function")
        loop_node = None
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # create a dummy loop to calculate the init condition
            init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

            # create Loop body graph with the existing nodes

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                shape = out_tensor_value_info.shape
                utils.make_sure(
                    shape is not None,
                    "Conversion of Loop requries output shape [{}] exists".format(out_tensor_value_info.id)
                )
                out_tensor_value_info.shape = utils.create_vague_shape_like(shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"), TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"), TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, create a placeholder
                    # input with the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                gb = GraphBuilder(loop_body_g)
                index_node = gb.make_unsqueeze({'data': input_ta.index_input_id, "axes": [0]}, return_node=True)
                gather_node = loop_body_g.make_node("Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = gb.make_squeeze({'data': gather_node.output[0], "axes": [0]}, return_node=True)
                loop_body_g.replace_all_inputs(input_ta.consumer.id, data_node.output[0])  # ops=loop_body_g.get_nodes()

            # create Loop node
            branches = {"body": loop_body_g}
            loop_node = self._create_loop_node(context, loop_props, init_cond_output, branches=branches)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error("loop rewrite failed, due to exception: %s, details:%s", ex, tb)
            return REWRITER_RESULT.FAIL
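
To make the per-iteration data access concrete, below is a minimal hand-written ONNX Loop model (plain onnx.helper, not the rewriter's actual output) that emulates a scan input the same way: the body reads data[i] via Unsqueeze -> Gather -> Squeeze and accumulates it into a loop-carried state. All names and shapes are illustrative assumptions.

import onnx
from onnx import TensorProto, helper

# opset 13: Unsqueeze/Squeeze take "axes" as an input tensor
axes0 = helper.make_tensor("axes0", TensorProto.INT64, [1], [0])

# Loop body: (iter_num, cond_in, acc_in) -> (cond_out, acc_out).
# "data" is resolved from the outer graph scope, which ONNX subgraphs allow.
body = helper.make_graph(
    [
        helper.make_node("Unsqueeze", ["iter_num", "axes0"], ["iter_1d"]),
        helper.make_node("Gather", ["data", "iter_1d"], ["row_2d"], axis=0),
        helper.make_node("Squeeze", ["row_2d", "axes0"], ["row"]),
        helper.make_node("Add", ["acc_in", "row"], ["acc_out"]),
        helper.make_node("Identity", ["cond_in"], ["cond_out"]),
    ],
    "loop_body",
    inputs=[
        helper.make_tensor_value_info("iter_num", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_in", TensorProto.FLOAT, [4]),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_out", TensorProto.FLOAT, [4]),
    ],
    initializer=[axes0],
)

loop_node = helper.make_node("Loop", ["trip_count", "cond", "acc_init"], ["acc_final"], body=body)

graph = helper.make_graph(
    [loop_node],
    "scan_via_gather",
    inputs=[
        helper.make_tensor_value_info("trip_count", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
        helper.make_tensor_value_info("acc_init", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [None, 4]),
    ],
    outputs=[helper.make_tensor_value_info("acc_final", TensorProto.FLOAT, [4])],
)

model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)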