Example #1
    def _normalize_program(self, program, feed_vars, fetch_vars):
        if not isinstance(program, Program):
            raise TypeError(
                "program type must be `fluid.Program`, but received `%s`" %
                type(program))
        if not isinstance(feed_vars, list):
            feed_vars = [feed_vars]
        if not all(isinstance(v, Variable) for v in feed_vars):
            raise TypeError(
                "feed_vars type must be a Variable or a list of Variable.")
        if not isinstance(fetch_vars, list):
            fetch_vars = [fetch_vars]
        if not all(isinstance(v, Variable) for v in fetch_vars):
            raise TypeError(
                "fetch_vars type must be a Variable or a list of Variable.")

        # Remind users to set auc_states to 0 if an auc op is found.
        for op in program.global_block().ops:
            # clear device of Op
            device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName(
            )
            op._set_attr(device_attr_name, "")
            if op.type == 'auc':
                warnings.warn("Be sure that you have set auc states to 0 "
                              "before saving inference model.")
                break

        # Clone the program and remove the existing feed/fetch ops.
        copy_program = program.clone()
        global_block = copy_program.global_block()
        remove_op_idx = []
        for i, op in enumerate(global_block.ops):
            op.desc.set_is_target(False)
            if op.type == "feed" or op.type == "fetch":
                remove_op_idx.append(i)
        for idx in remove_op_idx[::-1]:
            global_block._remove_op(idx)
        copy_program.desc.flush()

        feed_var_names = [var.name for var in feed_vars]
        copy_program = copy_program._prune_with_input(
            feeded_var_names=feed_var_names, targets=fetch_vars)
        copy_program = copy_program._inference_optimize(prune_read_op=True)
        fetch_var_names = [var.name for var in fetch_vars]
        prepend_feed_ops(copy_program, feed_var_names)
        append_fetch_ops(copy_program, fetch_var_names)
        copy_program.desc._set_version()
        return copy_program
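
The scalar-or-list normalization at the top of this method (wrap a single Variable in a list, then type-check every element) is a small reusable idiom. Below is a minimal standalone sketch in plain Python; the `_as_list_of` helper is hypothetical and a builtin type stands in for `Variable`:

def _as_list_of(value, expected_type, arg_name):
    # Accept either a single value or a list; always return a list.
    if not isinstance(value, list):
        value = [value]
    if not all(isinstance(v, expected_type) for v in value):
        raise TypeError("%s must be a %s or a list of %s." %
                        (arg_name, expected_type.__name__,
                         expected_type.__name__))
    return value

feed_names = _as_list_of("img", str, "feed_vars")   # -> ["img"]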
Example #2
def _normalize_program(program, feed_vars, fetch_vars):
    """
    Optimize the program according to feed_vars and fetch_vars.
    """
    # Remind users to set auc_states to 0 if an auc op is found.
    for op in program.global_block().ops:
        # clear device of Op
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        op._set_attr(device_attr_name, "")
        if op.type == 'auc':
            warnings.warn("Be sure that you have set auc states to 0 "
                          "before saving inference model.")
            break

    # Fix the bug that an activation op's output used as a fetch target would
    # be pruned, which would hurt inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with program_guard(program):
        uniq_fetch_vars = []
        for i, var in enumerate(fetch_vars):
            var = layers.scale(var,
                               1.,
                               name="save_infer_model/scale_{}".format(i))
            uniq_fetch_vars.append(var)
        fetch_vars = uniq_fetch_vars

    # Clone the program and remove the existing feed/fetch ops.
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_op_idx = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            remove_op_idx.append(i)
    for idx in remove_op_idx[::-1]:
        global_block._remove_op(idx)
    copy_program.desc.flush()

    feed_var_names = [var.name for var in feed_vars]
    copy_program = copy_program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    copy_program = copy_program._inference_optimize(prune_read_op=True)
    fetch_var_names = [var.name for var in fetch_vars]
    prepend_feed_ops(copy_program, feed_var_names)
    append_fetch_ops(copy_program, fetch_var_names)
    copy_program.desc._set_version()
    return copy_program
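
A detail worth noting in the feed/fetch cleanup above: `remove_op_idx` is traversed in reverse (`[::-1]`) so that removing one op does not shift the indices of ops still waiting to be removed. A plain-list sketch of the same idea (not Paddle-specific):

ops = ["feed", "mul", "softmax", "fetch"]
to_remove = [i for i, op_type in enumerate(ops) if op_type in ("feed", "fetch")]

# Deleting in reverse keeps the not-yet-deleted indices valid.
for idx in to_remove[::-1]:
    del ops[idx]

assert ops == ["mul", "softmax"]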
Example #3
def normalize_program(program, feed_vars, fetch_vars):
    """
    :api_attr: Static Graph

    Normalize/Optimize a program according to feed_vars and fetch_vars.

    Args:
        program(Program): Specify a program you want to optimize.
        feed_vars(Variable | list[Variable]): Variables needed by inference.
        fetch_vars(Variable | list[Variable]): Variables returned by inference.

    Returns:
        Program: Normalized/Optimized program.

    Raises:
        TypeError: If `program` is not a Program, an exception is thrown.
        TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
        TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            path_prefix = "./infer_model"

            # User-defined network; here, a softmax regression example
            image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
            predict = paddle.static.nn.fc(image, 10, activation='softmax')

            loss = paddle.nn.functional.cross_entropy(predict, label)

            exe = paddle.static.Executor(paddle.CPUPlace())
            exe.run(paddle.static.default_startup_program())

            # normalize main program.
            program = paddle.static.default_main_program()
            normalized_program = paddle.static.normalize_program(program, [image], [predict])

    """
    if not isinstance(program, Program):
        raise TypeError(
            "program type must be `fluid.Program`, but received `%s`" %
            type(program))
    if not isinstance(feed_vars, list):
        feed_vars = [feed_vars]
    if not all(isinstance(v, Variable) for v in feed_vars):
        raise TypeError(
            "feed_vars type must be a Variable or a list of Variable.")
    if not isinstance(fetch_vars, list):
        fetch_vars = [fetch_vars]
    if not all(isinstance(v, Variable) for v in fetch_vars):
        raise TypeError(
            "fetch_vars type must be a Variable or a list of Variable.")

    # Remind users to set auc_states to 0 if an auc op is found.
    for op in program.global_block().ops:
        # clear device of Op
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        op._set_attr(device_attr_name, "")
        if op.type == 'auc':
            warnings.warn("Be sure that you have set auc states to 0 "
                          "before saving inference model.")
            break

    # Fix the bug that an activation op's output used as a fetch target would
    # be pruned, which would hurt inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with program_guard(program):
        uniq_fetch_vars = []
        for i, var in enumerate(fetch_vars):
            if var.dtype != paddle.bool:
                var = layers.scale(
                    var, 1., name="save_infer_model/scale_{}".format(i))
            uniq_fetch_vars.append(var)
        fetch_vars = uniq_fetch_vars

    # Clone the program and remove the existing feed/fetch ops.
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_op_idx = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            remove_op_idx.append(i)
    for idx in remove_op_idx[::-1]:
        global_block._remove_op(idx)
    copy_program.desc.flush()

    feed_var_names = [var.name for var in feed_vars]
    # Prune the graph to the subgraph needed to compute fetch_vars from
    # feed_vars, then apply inference-only optimizations.
    copy_program = copy_program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    copy_program = copy_program._inference_optimize(prune_read_op=True)
    fetch_var_names = [var.name for var in fetch_vars]
    # Re-insert feed/fetch ops so the pruned program can be run for inference.
    prepend_feed_ops(copy_program, feed_var_names)
    append_fetch_ops(copy_program, fetch_var_names)
    copy_program.desc._set_version()
    return copy_program
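
A minimal sketch extending the docstring example above to inspect what normalization does; it only uses calls already shown in this section (inspecting ops via `global_block().ops` is assumed to behave the same outside the library internals):

import paddle

paddle.enable_static()

image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
predict = paddle.static.nn.fc(image, 10, activation='softmax')

program = paddle.static.default_main_program()
normalized = paddle.static.normalize_program(program, [image], [predict])

# The normalized clone keeps only the ops needed to compute `predict`,
# plus the feed/fetch ops re-inserted for inference.
print([op.type for op in program.global_block().ops])
print([op.type for op in normalized.global_block().ops])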