Example #1
    def _run(self, to_static):
        with fluid.dygraph.guard(self.place):
            if to_static:
                # Run the function after converting it to a static graph.
                ret = declarative(self.dyfunc)(self.len)
            else:
                # Run the function eagerly in dygraph mode.
                ret = self.dyfunc(self.len)
            return ret.numpy()
Example #2
    def _run(self, to_static):
        with fluid.dygraph.guard():
            if to_static:
                res = declarative(self.dygraph_func)(self.input).numpy()
            else:
                res = self.dygraph_func(self.input).numpy()
            return res
Example #3
    def test_output(self):
        with fluid.dygraph.guard():
            ret = sum_even_util_limit(80, 10)
            self.assertEqual(ret.numpy(), 30)

            ret = declarative(sum_under_while)(100)
            self.assertEqual(ret.numpy(), 5050)
Example #4
    def train(self, to_static=False):
        with fluid.dygraph.guard():
            if to_static:
                res = declarative(self.dygraph_func)(self.input, self.iter_num)
            else:
                res = self.dygraph_func(self.input, self.iter_num)
            return self.varbase_to_numpy(res)
Example #5
    def _run(self, to_static):
        with fluid.dygraph.guard(self.place):
            # Set the input of dyfunc to VarBase
            tensor_x = fluid.dygraph.to_variable(self.x, zero_copy=False)
            if to_static:
                ret = declarative(self.dyfunc)(tensor_x)
            else:
                ret = self.dyfunc(tensor_x)
            return ret.numpy()
Example #6
    def _run_dygraph(self, to_static=False):
        # NOTE: `place` is expected to be defined at module scope in the
        # original test file.
        with fluid.dygraph.guard(place):
            x_v = fluid.dygraph.to_variable(self.x)
            if to_static:
                ret = declarative(self.dyfunc)(x_v)
            else:
                ret = self.dyfunc(x_v)
            return ret.numpy()
Example #7
    def run_static_mode(self):
        with fluid.dygraph.guard():
            res = declarative(self.dygraph_func)(self.input)
            return res.numpy()
Example #8
# NOTE: the import paths below assume the Paddle 2.x internal layout used by
# this helper; adjust them for other versions.
import warnings

import six
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable, program_guard
from paddle.fluid.dygraph.jit import (declarative, _get_input_var_names,
                                      _get_output_vars)
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import (
    ProgramTranslator, StaticFunction)
from paddle.fluid.io import (_get_valid_program, append_fetch_ops,
                             prepend_feed_ops)
from paddle.fluid.layers.utils import flatten, pack_sequence_as


def get_program(layer, input_spec, output_spec, **configs):
    paddle.jit.set_verbosity(0)
    prog_translator = ProgramTranslator()
    if not prog_translator.enable_to_static:
        raise RuntimeError(
            "Paddle2onnx does not work when ProgramTranslator.enable is set to False."
        )

    if not isinstance(layer, Layer):
        raise TypeError(
            "The input of paddle2onnx should be 'Layer', but received input type is %s."
            % type(layer))

    if isinstance(layer, paddle.DataParallel):
        inner_layer = layer._layers
    else:
        inner_layer = layer

    # avoid changing the user-given input_spec
    inner_input_spec = None
    if input_spec is not None:
        for attr_func in dir(inner_layer):
            static_func = getattr(inner_layer, attr_func, None)
            if isinstance(static_func,
                          StaticFunction) and 'forward' != attr_func:
                raise ValueError(
                    "If there are static functions other than 'forward' that need to be saved, the input 'input_spec' should be None, but received the type of 'input_spec' is %s."
                    % type(input_spec))

        if not isinstance(input_spec, (list, tuple)):
            raise TypeError(
                "The input_spec should be a 'list', but received input_spec's type is %s."
                % type(input_spec))
        inner_input_spec = []
        for var in flatten(input_spec):
            if isinstance(var, paddle.static.InputSpec):
                inner_input_spec.append(var)
            elif isinstance(var, (core.VarBase, core.eager.Tensor, Variable)):
                inner_input_spec.append(
                    paddle.static.InputSpec.from_tensor(var))
            else:
                # NOTE(Aurelius84): Support non-Tensor type in `input_spec`.
                inner_input_spec.append(var)

    extra_var_info = dict()
    functions = dir(inner_layer)
    for attr_func in functions:
        static_func = getattr(inner_layer, attr_func, None)
        if isinstance(static_func, StaticFunction):
            concrete_program = static_func.concrete_program_specify_input_spec(
                inner_input_spec)
        elif 'forward' == attr_func:
            # As in jit.save: if input_spec is incomplete, declarative will
            # raise an error. inner_input_spec is a flat list[InputSpec], so it
            # must be packed back into the same structure as the original
            # input_spec here.
            if inner_input_spec:
                inner_input_spec = pack_sequence_as(input_spec,
                                                    inner_input_spec)
            static_forward = declarative(inner_layer.forward,
                                         input_spec=inner_input_spec)
            concrete_program = static_forward.concrete_program
            # input_spec has already been consumed by declarative, which is
            # equivalent to using @declarative with input_spec and calling
            # jit.save without one, so clear it to avoid a needless warning.
            inner_input_spec = None
        else:
            continue

        input_var_names = _get_input_var_names(concrete_program.inputs,
                                               inner_input_spec)

        # NOTE(chenweihang): [ Get output variables ]
        # The rule mirrors [ Get input variables name ]. For outputs, only
        # VarBase specs are supported; in practice only the output var names
        # are needed, so using output_spec is not recommended.
        output_vars = _get_output_vars(concrete_program.outputs, output_spec)

    feeded_var_names = input_var_names
    target_vars = output_vars
    main_program = concrete_program.main_program.clone()
    export_for_deployment = True

    if isinstance(feeded_var_names, six.string_types):
        feeded_var_names = [feeded_var_names]
    elif export_for_deployment:
        if len(feeded_var_names) > 0:
            # TODO(paddle-dev): polish these code blocks
            if not (bool(feeded_var_names) and all(
                    isinstance(name, six.string_types)
                    for name in feeded_var_names)):
                raise ValueError("'feed_var_names' should be a list of str.")

    if isinstance(target_vars, Variable):
        target_vars = [target_vars]
    elif export_for_deployment:
        if not (bool(target_vars)
                and all(isinstance(var, Variable) for var in target_vars)):
            raise ValueError("'target_vars' should be a list of Variable.")

    main_program = _get_valid_program(main_program)

    # remind user to set auc_states to zeros if the program contains auc op
    all_ops = main_program.global_block().ops
    for op in all_ops:
        # clear device of Op
        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
        op._set_attr(device_attr_name, "")
        if op.type == 'auc':
            warnings.warn(
                "please ensure that you have set the auc states to zeros before saving the inference model"
            )
            break

    with program_guard(main_program):
        uniq_target_vars = []
        for i, var in enumerate(target_vars):
            uniq_target_vars.append(var)
        target_vars = uniq_target_vars
    target_var_name_list = [var.name for var in target_vars]

    origin_program = main_program.clone()

    main_program = main_program.clone()
    global_block = main_program.global_block()
    need_to_remove_op_index = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            need_to_remove_op_index.append(i)

    for index in need_to_remove_op_index[::-1]:
        global_block._remove_op(index)

    main_program.desc.flush()

    main_program = main_program._prune_with_input(
        feeded_var_names=feeded_var_names, targets=target_vars)
    main_program = main_program._inference_optimize(prune_read_op=True)
    fetch_var_names = [v.name for v in target_vars]

    for target_v in target_vars:
        if not main_program.global_block().has_var(target_v.name):
            main_program.global_block().create_var(
                name=target_v.name,
                shape=target_v.shape,
                dtype=target_v.dtype,
                persistable=target_v.persistable)

    prepend_feed_ops(main_program, feeded_var_names)
    append_fetch_ops(main_program, fetch_var_names)

    main_program.desc._set_version()
    paddle.fluid.core.save_op_version_info(main_program.desc)

    main_program._copy_dist_param_info_from(origin_program)

    return main_program, feeded_var_names, target_vars
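For context, here is a hedged sketch of how get_program might be called; SimpleNet, the layer sizes, and the InputSpec are illustrative assumptions, not taken from the source:

import paddle
from paddle.static import InputSpec

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()
net.eval()
spec = [InputSpec(shape=[None, 4], dtype='float32', name='x')]
main_prog, feed_names, fetch_vars = get_program(net, spec, output_spec=None)

Because forward is not decorated here, get_program wraps it with declarative(inner_layer.forward, input_spec=...) and extracts the concrete program from that wrapper before pruning and adding the feed/fetch ops.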