Code example #1
 def _get_output_shape_dtype(self, cond_context):
     output_shapes = []
     output_dtypes = []
     for i, _ in enumerate(cond_context.true_branch_context.output):
         true_output = cond_context.true_branch_context.output[i]
         false_output = cond_context.false_branch_context.output[i]
         true_shape = self.g.get_shape(true_output)
         utils.make_sure(true_shape is not None,
                         "Shape of {} is None".format(true_output))
         true_rank = len(true_shape)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         utils.make_sure(false_shape is not None,
                         "Shape of {} is None".format(false_output))
         false_rank = len(false_shape)
         false_dtype = self.g.get_dtype(false_output)
         # only require that the ranks match
         if true_rank != false_rank:
             raise RuntimeError(
                 "the rank of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_rank, false_rank))
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output, false_output, true_dtype, false_dtype))
         output_shapes.append(utils.create_vague_shape_like(true_shape))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes
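
The helper these examples revolve around, utils.create_vague_shape_like, is not shown on this page. Its behavior can be inferred from the call sites and from the comment in example #4 ("use None for all dims, just keep original rank"): it keeps the rank of a shape but marks every dimension unknown. A minimal sketch under that assumption (a hypothetical reconstruction, not the tf2onnx source):

    def create_vague_shape_like(shape):
        # assumed behavior: keep the rank but mark every dim unknown (-1),
        # e.g. [3, 4] -> [-1, -1] and [] -> []
        return [-1 for _ in shape]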
Code example #2
 def _get_output_shape_dtype(self, cond_context):
     output_shapes = []
     output_dtypes = []
     for i, _ in enumerate(cond_context.true_branch_context.output):
         true_output = cond_context.true_branch_context.output[i]
         false_output = cond_context.false_branch_context.output[i]
         true_shape = self.g.get_shape(true_output)
         true_dtype = self.g.get_dtype(true_output)
         false_shape = self.g.get_shape(false_output)
         false_dtype = self.g.get_dtype(false_output)
         if not utils.are_shapes_compatible(true_shape, false_shape):
             raise RuntimeError(
                 "the shape of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_shape,
                     false_shape
                 )
             )
         if true_dtype != false_dtype:
             raise RuntimeError(
                 "the dtype of outputs {} and {} mismatch: {}, {}".format(
                     true_output,
                     false_output,
                     true_dtype,
                     false_dtype
                 )
             )
         # in tf, the shapes of different branches can differ;
         # for example, branch A's output shape can be [-1] while branch B's is [1].
         # In that case we should set the output shape to [-1].
         output_shapes.append(utils.create_vague_shape_like(utils.merge_shapes(true_shape, false_shape)))
         output_dtypes.append(true_dtype)
     return output_shapes, output_dtypes
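
Example #2 also relies on utils.are_shapes_compatible and utils.merge_shapes, whose semantics are suggested by the comment above the merge call. The following are hedged reconstructions under that reading, not the library's code:

    def are_shapes_compatible(s1, s2):
        # compatible when ranks match and every dim pair agrees,
        # treating -1/None (unknown) as a wildcard; a None shape matches anything
        if s1 is None or s2 is None:
            return True
        if len(s1) != len(s2):
            return False
        return all(d1 in (-1, None) or d2 in (-1, None) or d1 == d2
                   for d1, d2 in zip(s1, s2))

    def merge_shapes(s1, s2):
        # prefer the concrete dim of each pair, e.g. [-1] and [1] -> [1];
        # create_vague_shape_like then relaxes the merged result back to [-1]
        if s1 is None:
            return s2
        if s2 is None:
            return s1
        return [d1 if d1 not in (-1, None) else d2 for d1, d2 in zip(s1, s2)]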
Code example #3
File: select.py Project: skottmckay/tensorflow-onnx
def create_body_graph_for_if_branch(parent_g, data_type, output_shape, chosen_cur_cond_val_out_name, op_name):
    g = parent_g.create_new_graph_with_same_config()
    name = utils.make_name("Identity")
    g.make_node(
        'Identity',
        inputs=[chosen_cur_cond_val_out_name],
        outputs=['y'],
        name=name
    )
    g.add_graph_output("y", data_type, utils.create_vague_shape_like(output_shape))
    return g
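
How the returned branch graph is consumed is not shown here. A plausible wiring, mirroring the branches-dict pattern from example #5 (the variable names and the exact If-node plumbing are assumptions):

    # hypothetical usage sketch; the "If" wiring details are assumed
    true_g = create_body_graph_for_if_branch(g, dtype, shape, true_value_id, op_name)
    false_g = create_body_graph_for_if_branch(g, dtype, shape, false_value_id, op_name)
    if_node = g.make_node("If", [cond_output_id],
                          branches={"then_branch": true_g, "else_branch": false_g})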
Code example #4
File: controlflow.py Project: wkkyle/tensorflow-onnx
def create_loop_body_graph(parent_g, gather_input_ids, output_data_type, output_shape, trip_count_input_ids,
                           rank, loop_name):
    g = parent_g.create_new_graph_with_same_config()
    g.parent_graph = parent_g
    iter_name = utils.make_name("i")
    cond_name = utils.make_name("cond")
    fake_var_name = utils.make_name("fake_var")

    g.add_graph_input(iter_name, TensorProto.INT64, (1,))  # iteration_num
    g.add_graph_input(cond_name, TensorProto.BOOL, ())  # condition
    g.add_graph_input(fake_var_name, TensorProto.FLOAT, ())  # loop-carried dependency

    # get the i'th value of condition
    cond_input_id = gather_input_ids[0]
    cond_input_id_for_current_iter = get_inputs_for_current_iteration(g, cond_input_id, iter_name)

    # get the i'th value of true values
    true_input_id = gather_input_ids[1]
    true_input_id_for_current_iter = get_inputs_for_current_iteration(g, true_input_id, iter_name)

    # get the i'th value of false values
    false_input_id = gather_input_ids[2]
    false_input_id_for_current_iter = get_inputs_for_current_iteration(g, false_input_id, iter_name)

    input_ids_for_current_iter = [cond_input_id_for_current_iter, true_input_id_for_current_iter,
                                  false_input_id_for_current_iter]
    output_id = None
    rank -= 1
    if rank >= 1:
        loop_1 = create_loop_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:],
                                trip_count_input_ids, rank)
        output_id = loop_1.output[1]
    elif rank == 0:
        _, if_node_output_id = create_if_op(g, input_ids_for_current_iter, output_data_type, output_shape[1:])
        output_id = if_node_output_id

    output_identity_name = utils.make_name("loop_output")
    loop_output_id = utils.port_name(output_identity_name)
    g.make_node(
        'Identity',
        [output_id],
        outputs=[loop_output_id],
        name=output_identity_name
    )

    cond_identity_name = utils.make_name("cond_output")
    cond_output_id = utils.port_name(cond_identity_name)
    g.make_node(
        'Identity',
        [cond_name],
        outputs=[cond_output_id],
        name=cond_identity_name
    )

    fake_var_identity_name = utils.make_name("fake_var_output")
    fake_var_output_id = utils.port_name(fake_var_identity_name)
    g.make_node(
        'Identity',
        [fake_var_name],
        outputs=[fake_var_output_id],
        name=fake_var_identity_name
    )

    g.add_graph_output(cond_output_id, TensorProto.BOOL, ())
    g.add_graph_output(fake_var_output_id, TensorProto.FLOAT, ())

    # use None for all dims, keeping only the original rank, because dims have been observed to change inside the loop.
    g.add_graph_output(loop_output_id, output_data_type, utils.create_vague_shape_like(output_shape[1:]))

    return g
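
get_inputs_for_current_iteration is not shown either; judging from the Unsqueeze/Gather/Squeeze pattern used for tensor-array inputs in examples #5 and #6, it plausibly selects the i'th slice along axis 0. A hypothetical reconstruction:

    def get_inputs_for_current_iteration(g, input_id, iter_name):
        # iter_name has shape (1,), so Gather keeps a leading dim of 1,
        # which the Squeeze with axes=[0] then removes
        gather_node = g.make_node("Gather", [input_id, iter_name])
        squeeze_node = g.make_node("Squeeze", [gather_node.output[0]],
                                   attr={"axes": [0]})
        return squeeze_node.output[0]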
Code example #5
    def rewrite(self, context):
        logger.debug("enter rewrite function")
        loop_node = None
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # create a dummy loop to calculate the init condition
            init_cond_output = self._create_subgraph_initial_cond(cond_g_info)

            ## create Loop body graph with existing nodes

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                shape = out_tensor_value_info.shape
                utils.make_sure(
                    shape is not None,
                    "Conversion of Loop requries output shape [{}] exists".
                    format(out_tensor_value_info.id))
                out_tensor_value_info.shape = utils.create_vague_shape_like(
                    shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
                self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"),
                                        TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"),
                                        TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, we create a fake one
                    # with the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(
                    input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                index_node = loop_body_g.make_node("Unsqueeze",
                                                   [input_ta.index_input_id],
                                                   attr={"axes": [0]})
                gather_node = loop_body_g.make_node(
                    "Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = loop_body_g.make_node("Squeeze",
                                                  [gather_node.output[0]],
                                                  attr={"axes": [0]})
                loop_body_g.replace_all_inputs(
                    input_ta.consumer.id,
                    data_node.output[0])  # ops=loop_body_g.get_nodes()

            ## create Loop node
            branches = {"body": loop_body_g}
            loop_node = self._create_loop_node(context,
                                               loop_props,
                                               init_cond_output,
                                               branches=branches)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error(
                "loop rewrite failed, due to exception: %s, details:%s", ex,
                tb)
            return REWRITER_RESULT.FAIL
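
The Unsqueeze/Gather/Squeeze triple above emulates a scan input by reading one slice of the tensor-array data per iteration. The same index arithmetic in plain numpy (illustration only):

    import numpy as np

    data = np.arange(12).reshape(3, 4)  # tensor-array data, axis 0 = iterations
    i = np.array(1, dtype=np.int64)     # current iteration counter (scalar)
    idx = np.expand_dims(i, 0)          # Unsqueeze axes=[0] -> shape (1,)
    row = np.take(data, idx, axis=0)    # Gather             -> shape (1, 4)
    cur = np.squeeze(row, axis=0)       # Squeeze axes=[0]   -> shape (4,)
    assert cur.tolist() == [4, 5, 6, 7]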
Code example #6
    def rewrite(self, context):
        logger.debug("enter rewrite function")
        loop_node = None
        try:
            loop_props = context.loop_properties
            cell_g_info = context.cell_graph
            cond_g_info = context.cond_graph

            # todo(pengwa): we don't check the case where the loop body is never executed at all.

            ## create Loop body graph with existing nodes

            # replace the condition graph's inputs with the cell graph's outputs, because we want
            # the condition graph to consume the cell graph's outputs.
            for loop_var in cond_g_info.dependent_vars:
                self.g.replace_all_inputs(
                    cond_g_info.nodes, loop_var.switch_true_identity_output.id,
                    loop_var.next_iteration_input.id)

            body_nodes = set(cell_g_info.nodes + cond_g_info.nodes)
            body_outputs = cond_g_info.outputs + cell_g_info.outputs
            for out_tensor_value_info in body_outputs:
                out_tensor_value_info.shape = utils.create_vague_shape_like(
                    out_tensor_value_info.shape)

            loop_body_g = LoopRewriterBase.construct_graph_from_nodes(
                self.g, body_nodes, body_outputs)

            # create loop body graph inputs
            loop_body_g.add_graph_input(utils.make_name("i"),
                                        TensorProto.INT64, ())
            loop_body_g.add_graph_input(utils.make_name("cond"),
                                        TensorProto.BOOL, ())
            for i, tensor_value_info in enumerate(loop_props.state_inputs):
                input_name = tensor_value_info.id
                if input_name is None:
                    # if the variable is not used in the body graph, we create a fake one
                    # with the same type and shape as its corresponding output.
                    out_tensor_value_info = loop_props.state_outputs[i]
                    dtype = out_tensor_value_info.dtype
                    shape = out_tensor_value_info.shape
                    input_name = utils.make_name("unused_state_input_")
                else:
                    dtype = tensor_value_info.dtype
                    shape = tensor_value_info.shape

                loop_body_g.add_graph_input(
                    input_name, dtype, utils.create_vague_shape_like(shape))

            for input_ta in loop_props.tensor_array_inputs:
                # Loop does not have scan inputs, so we use Gather to get data for each iteration.
                index_node = loop_body_g.make_node("Unsqueeze",
                                                   [input_ta.index_input_id],
                                                   attr={"axes": [0]})
                gather_node = loop_body_g.make_node(
                    "Gather", [input_ta.data_input_id, index_node.output[0]])
                data_node = loop_body_g.make_node("Squeeze",
                                                  [gather_node.output[0]],
                                                  attr={"axes": [0]})
                loop_body_g.replace_all_inputs(loop_body_g.get_nodes(),
                                               input_ta.consumer.id,
                                               data_node.output[0])

            ## create Loop node
            loop_node = self._create_loop_node(context, loop_props)
            if not loop_node:
                logger.error("failed to create loop node during rewrite")
                return REWRITER_RESULT.FAIL
            loop_node.set_body_graph_as_attr("body", loop_body_g)

            logger.debug("rewrite successfully")
            return REWRITER_RESULT.OK

        except Exception as ex:
            tb = traceback.format_exc()
            logger.error(
                "loop rewrite failed, due to exception: %s, details:%s", ex,
                tb)
            return REWRITER_RESULT.FAIL
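
Examples #5 and #6 appear to come from different revisions of the same rewriter: #6 passes an explicit node list to replace_all_inputs and attaches the body graph afterwards via set_body_graph_as_attr, while #5 uses the newer branches= argument and drops the node-list parameter. Either way, replace_all_inputs plausibly behaves like the following hypothetical sketch (not the tf2onnx source):

    def replace_all_inputs(ops, old_input_id, new_input_id):
        # point every node input that currently references old_input_id
        # at new_input_id instead
        for node in ops:
            node.input = [new_input_id if inp == old_input_id else inp
                          for inp in node.input]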