コード例 #1
0
    def from_func_spec(func_spec, input_spec, class_instance):
        """
        Traces the decorated function once inside a fresh pair of programs and
        packs the result into a ConcreteProgram.

        Args:
            func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
            input_spec(list[InputSpec]): Specs forwarded to
                `func_spec.to_static_inputs_with_spec` to create the inputs.
            class_instance: Object whose parameters and buffers are looked up
                through `get_parameters` / `get_buffers`. When truthy it is
                also placed in front of the inputs before the call.

        Returns:
            ConcreteProgram: Bundles the inputs, the outputs, the collected
            parameters and buffers, the original dygraph function and both
            programs.
        """
        original_fn = func_spec.dygraph_function
        # The converted (static) function is what actually gets traced below.
        converted_fn = convert_to_static(original_fn)

        main_program = framework.Program()
        startup_program = framework.Program()
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        default_main = framework.default_main_program()
        main_program.random_seed = default_main.random_seed
        default_startup = framework.default_startup_program()
        startup_program.random_seed = default_startup.random_seed

        with framework.program_guard(main_program, startup_program), \
                _switch_declarative_mode_guard_(is_declarative=True):
            # Step 1: create the `fluid.data` layers for the inputs if needed.
            inputs = func_spec.to_static_inputs_with_spec(input_spec,
                                                          main_program)
            if class_instance:
                # The instance becomes the first positional argument.
                inputs = (class_instance, ) + tuple(inputs)

            # Step 2: collect every ParamBase and buffered VarBase.
            param_values = list(get_parameters(class_instance).values())
            buffer_values = list(get_buffers(class_instance).values())
            all_parameters_and_buffers = param_values + buffer_values

            # Step 3: build the program exactly once and keep its outputs.
            with param_guard(get_parameters(class_instance, False)):
                with param_guard(get_buffers(class_instance, False)):
                    try:
                        outputs = converted_fn(*inputs)
                    except BaseException as e:
                        # NOTE: If e is raised in compile time, e should be
                        # attached to ERROR_DATA here.
                        attach_error_data(e)
                        raise

            # A lone non-None result is wrapped; None and sequences pass as is.
            if outputs is not None and not isinstance(outputs, (tuple, list)):
                outputs = [outputs]

        main_program = update_op_callstack_with_origin_info(main_program)

        return ConcreteProgram(
            inputs=inputs,
            outputs=outputs,
            parameters=all_parameters_and_buffers,
            function=original_fn,
            main_program=main_program,
            startup_program=startup_program)
コード例 #2
0
    def from_func_spec(func_spec):
        """
        Builds the main_program with specialized inputs and returns outputs
        of program as fetch_list.

        Args:
            func_spec: Object providing `dyfunc`, `to_static_inputs(program)`
                and `parameters(...)` (presumably a FunctionSpec; its
                definition is not visible here).

        Returns:
            ConcreteProgram: Bundles the inputs, the outputs (always a tuple
            or list; empty list when the function returned None), the
            parameters, the original dygraph function and both programs.
        """
        # Transforms dygraph function into static function and caches it.
        dygraph_function = func_spec.dyfunc
        static_func = convert_function_with_cache(dygraph_function)

        main_program, startup_program = framework.Program(), framework.Program(
        )
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        main_program.random_seed = framework.default_main_program().random_seed
        startup_program.random_seed = framework.default_startup_program(
        ).random_seed

        with framework.program_guard(main_program, startup_program):
            # 1. Adds `fluid.data` layers for input if needed
            inputs = func_spec.to_static_inputs(main_program)

            # 2. Gets all ParamBases in the function
            all_parameters = list(func_spec.parameters().values())

            # 3. Builds program only once and returns the output Variables.
            with param_guard(func_spec.parameters(False)):
                outputs = static_func(*inputs)
            if not isinstance(outputs, (tuple, list)):
                # Compare against None explicitly: relying on the truthiness
                # of the output would drop a valid result that happens to be
                # falsy, or fail for objects whose truth value is undefined.
                outputs = [] if outputs is None else [outputs]

        return ConcreteProgram(inputs=inputs,
                               outputs=outputs,
                               parameters=all_parameters,
                               func=dygraph_function,
                               main_program=main_program,
                               startup_program=startup_program)
コード例 #3
0
ファイル: primx.py プロジェクト: sandyhouse/Paddle
def _lower(block, reverse):
    """Rewrites the ops of `block` through a lowering rule table.

    Each op present when the function starts is either lowered (when the
    lookup function returns a rule for `op.type`) or re-appended as a copy
    whose inputs refer to the already lowered variables. Afterwards the
    original ops are removed, the lowered variables are renamed back to the
    names of the outputs they replace, and the temporary variables are
    dropped from the block.

    Args:
        block: The block whose ops and vars are rewritten in place.
        reverse: When truthy `_prim2orig` / `lookup_prim2orig` are used,
            otherwise `_orig2prim` / `lookup_orig2prim`.
    """

    # Helpers that are only needed inside _lower.
    def bind(args, to_bind, value_table):
        # Replace, in place, every variable whose name has been re-bound.
        for idx, arg in enumerate(args):
            if isinstance(arg, list):
                bind(arg, to_bind, value_table)
            elif arg is not None and arg.name in to_bind:
                args[idx] = value_table[to_bind[arg.name]]

    def bind_name(names, to_bind):
        # Same as `bind`, but works on names and returns a new nested list.
        return [
            bind_name(name, to_bind)
            if isinstance(name, list) else to_bind.get(name, name)
            for name in names
        ]

    def expand_nested_list(xs):
        # Flatten arbitrarily nested lists into a single flat list.
        flat = []
        for x in xs:
            if isinstance(x, list):
                flat.extend(expand_nested_list(x))
            else:
                flat.append(x)
        return flat

    # Step1: Do some preparatory work for lower
    if reverse:
        lower_fn, lookup_fn = _prim2orig, lookup_prim2orig
    else:
        lower_fn, lookup_fn = _orig2prim, lookup_orig2prim

    value_table = {}
    for var_desc in block.desc.all_vars():
        value_table[var_desc.name()] = block.var(var_desc.name())
    to_bind = {}  # original output name -> lowered output name
    to_bind_rev = {}  # lowered output name -> original output name

    ops_to_remove = []
    vars_to_remove = set()

    # Step2: Process all ops in the target block
    # Only the ops present right now are visited; the loop itself appends
    # new ops to `block.ops`, so the count is captured up front.
    original_op_count = len(block.ops)
    for op_idx in range(original_op_count):
        op = block.ops[op_idx]
        ops_to_remove.append(op_idx)

        if lookup_fn(op.type) is None:
            # No rule: append a copy of the op with re-bound input names.
            inputs = {
                name: bind_name(op.input(name), to_bind)
                for name in op.input_names
            }
            outputs = {name: op.output(name) for name in op.output_names}
            attrs = {name: op.attr(name) for name in sorted(op.attr_names)}

            from paddle.fluid.dygraph.base import param_guard
            new_op_desc = block.desc.append_op()
            with param_guard(inputs):
                with param_guard(outputs):
                    copied_op = Operator(block=block,
                                         desc=new_op_desc,
                                         type=op.type,
                                         inputs=inputs,
                                         outputs=outputs,
                                         attrs=attrs)
            block.ops.append(copied_op)
            continue

        # A rule exists: lower the op and record how its outputs map.
        input_args = get_input_var_list(op)
        bind(input_args, to_bind, value_table)

        orig_outs = expand_nested_list(get_output_var_list(op))
        new_outs = expand_nested_list(to_tensors(lower_fn(op, *input_args)))
        for orig_out, new_out in zip(orig_outs, new_outs):
            assert (orig_out is None) == (
                new_out is None), "orig_out and new_out should match."
            vars_to_remove.add(new_out.name)
            value_table[new_out.name] = new_out
            to_bind[orig_out.name] = new_out.name
            to_bind_rev[new_out.name] = orig_out.name

    # Step3: Do some post-processing work
    # Remove from the back so that the remaining indices stay valid.
    for op_idx in reversed(ops_to_remove):
        block.desc._remove_op(op_idx, op_idx + 1)
        del block.ops[op_idx]
    block._sync_with_cpp()

    # Give the lowered variables the names of the outputs they replace.
    for op in block.ops:
        for in_name in op.input_arg_names:
            if in_name in to_bind_rev:
                op._rename_input(in_name, to_bind_rev[in_name])

        for out_name in op.output_arg_names:
            if out_name in to_bind_rev:
                op._rename_output(out_name, to_bind_rev[out_name])

    # Drop the temporary variables whose name differs from the original one.
    for var_name in sorted(vars_to_remove):
        assert var_name in to_bind_rev, 'var_name "{}" is not in to_bind_rev.'.format(
            var_name)
        if var_name == to_bind_rev[var_name]:
            continue
        block.desc._remove_var(cpt.to_bytes(var_name))
        del block.vars[var_name]
    block._sync_with_cpp()
コード例 #4
0
    def from_func_spec(func_spec, input_spec, input_kwargs_spec,
                       class_instance, **kwargs):
        """
        Traces the decorated function once inside a fresh pair of programs and
        packs the result into a ConcreteProgram.

        Args:
            func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
            input_spec(list[InputSpec]): Specs forwarded to
                `func_spec.to_static_inputs_with_spec` to create the
                positional inputs.
            input_kwargs_spec: Specs forwarded to the same method to create
                the keyword inputs; when the result is falsy the function is
                called without keyword arguments.
            class_instance: Object whose parameters and buffers are looked up.
                When truthy it is also placed in front of the positional
                inputs.
            **kwargs: `with_hook` (default False) is handed to HookHelper;
                every entry is forwarded to ConcreteProgram.

        Returns:
            ConcreteProgram: Bundles the static inputs, the outputs, the
            collected parameters and buffers, the original dygraph function
            and both programs.
        """
        # verify the instance is initialized in imperative mode.
        _verify_init_in_dynamic_mode(class_instance)

        original_fn = func_spec.dygraph_function
        # The converted (static) function is what actually gets traced below.
        converted_fn = convert_to_static(original_fn)
        # apply pre\post hook for outermost layer
        with_hook = kwargs.get("with_hook", False)
        hook_helper = HookHelper(original_fn, class_instance, with_hook)

        main_program = framework.Program()
        startup_program = framework.Program()
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        default_main = framework.default_main_program()
        main_program.random_seed = default_main.random_seed
        default_startup = framework.default_startup_program()
        startup_program.random_seed = default_startup.random_seed

        from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
        with framework.program_guard(main_program, startup_program), \
                _switch_declarative_mode_guard_(is_declarative=True):
            # Step 1: create the `fluid.data` layers for the inputs if needed.
            static_inputs = func_spec.to_static_inputs_with_spec(
                input_spec, main_program)
            static_kwargs = func_spec.to_static_inputs_with_spec(
                input_kwargs_spec, main_program)
            if class_instance:
                # The instance becomes the first positional argument.
                static_inputs = (class_instance, ) + tuple(static_inputs)

            # Step 2: collect every ParamBase and buffered VarBase.
            all_parameters_and_buffers = _extract_indeed_params_buffers(
                class_instance)

            # Step 3: build the program exactly once and keep its outputs.
            with param_guard(get_parameters(class_instance, False)):
                with param_guard(get_buffers(class_instance, False)):
                    try:
                        # only for jit.save, do nothing while train and eval process
                        inputs = hook_helper.apply_pre_hooks(static_inputs)
                        if not static_kwargs:
                            outputs = converted_fn(*inputs)
                        else:
                            outputs = converted_fn(*inputs, **static_kwargs)
                        outputs = hook_helper.apply_post_hooks(inputs, outputs)
                    except BaseException as e:
                        # NOTE: If e is raised in compile time, e should be
                        # attached to ERROR_DATA here.
                        error.attach_error_data(e)
                        attached = getattr(e, error.ERROR_DATA, None)
                        if attached:
                            attached.raise_new_exception()
                        raise

            if outputs is not None:
                # Wrap a lone value, and also a sequence holding one element.
                is_sequence = isinstance(outputs, (tuple, list))
                if not is_sequence or len(outputs) == 1:
                    outputs = [outputs]

        main_program = update_op_callstack_with_origin_info(main_program)

        return ConcreteProgram(inputs=static_inputs,
                               outputs=outputs,
                               parameters=all_parameters_and_buffers,
                               function=original_fn,
                               main_program=main_program,
                               startup_program=startup_program,
                               **kwargs)