def from_func_spec(func_spec, input_spec, class_instance):
    """
    Converts the decorated dygraph function to a static one, runs it once
    inside freshly created programs and wraps the result in a
    ConcreteProgram.

    Args:
        func_spec(FunctionSpec): A FunctionSpec instance for the decorated
            function. Its `dygraph_function` is converted and its
            `to_static_inputs_with_spec` creates the static inputs.
        input_spec(list[InputSpec]): Specs forwarded to
            `func_spec.to_static_inputs_with_spec`.
        class_instance: Object the function belongs to. When truthy it is
            placed in front of the inputs. Its parameters and buffers are
            collected through `get_parameters` / `get_buffers`.

    Returns:
        ConcreteProgram: Holds the inputs, the outputs (wrapped into a list
        unless they already are a tuple/list or None), the collected
        parameters and buffers, the original dygraph function and both
        programs.
    """
    # Convert the dygraph function into its static counterpart.
    dygraph_function = func_spec.dygraph_function
    static_fn = convert_to_static(dygraph_function)

    main_prog = framework.Program()
    startup_prog = framework.Program()
    # Note: The random seed should be synchronized into cached program
    # if set in `fluid.dygraph_guard` because some ops rely on it, such as
    # `fluid.layers.dropout`.
    main_prog.random_seed = framework.default_main_program().random_seed
    startup_prog.random_seed = (
        framework.default_startup_program().random_seed)

    with framework.program_guard(main_prog, startup_prog), \
            _switch_declarative_mode_guard_(is_declarative=True):
        # 1. Adds `fluid.data` layers for input if needed.
        inputs = func_spec.to_static_inputs_with_spec(input_spec, main_prog)
        if class_instance:
            inputs = (class_instance, ) + tuple(inputs)

        # 2. Gets all ParamBases and buffered VarBases in the function;
        #    parameters come first, buffers follow.
        params_and_buffers = list(get_parameters(class_instance).values())
        params_and_buffers.extend(get_buffers(class_instance).values())

        # 3. Builds program only once and returns the output Variables.
        with param_guard(get_parameters(class_instance, False)):
            with param_guard(get_buffers(class_instance, False)):
                try:
                    outputs = static_fn(*inputs)
                except BaseException as exc:
                    # NOTE: If the exception is raised in compile time, it
                    # should be attached to ERROR_DATA here.
                    attach_error_data(exc)
                    raise

        # A lone, non-None output is normalized to a one-element list.
        if outputs is not None and not isinstance(outputs, (tuple, list)):
            outputs = [outputs]

    main_prog = update_op_callstack_with_origin_info(main_prog)

    return ConcreteProgram(
        inputs=inputs,
        outputs=outputs,
        parameters=params_and_buffers,
        function=dygraph_function,
        main_program=main_prog,
        startup_program=startup_prog)
def _extract_indeed_params_buffers(class_instance):
    """
    Collects the parameters of `class_instance` followed by its buffers,
    leaving out buffers that are not initialized.

    A buffer is treated as not initialized when its `shape` is empty.

    Args:
        class_instance: Object passed to `get_parameters` and `get_buffers`.

    Returns:
        list: All parameter values, then the buffers that were kept.
    """
    collected = list(get_parameters(class_instance).values())
    for candidate in list(get_buffers(class_instance).values()):
        # An empty shape marks a buffer that has not been initialized.
        if len(candidate.shape) != 0:
            collected.append(candidate)
    return collected
def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
                   **kwargs):
    """
    Converts the decorated dygraph function to a static one, runs it once
    inside freshly created programs and wraps the result in a
    ConcreteProgram.

    Args:
        func_spec(FunctionSpec): A FunctionSpec instance for the decorated
            function. Its `dygraph_function` is converted and its
            `to_static_inputs_with_spec` creates the static inputs.
        input_spec(list[InputSpec]): Specs used to create the positional
            static inputs.
        input_kwargs_spec: Specs used to create the keyword static inputs;
            they go through the same `to_static_inputs_with_spec` call.
        class_instance: Object the function belongs to. When truthy it is
            placed in front of the positional inputs. Its parameters and
            buffers are collected for the resulting program.
        **kwargs: Forwarded unchanged to ConcreteProgram. The `with_hook`
            entry (default False) is additionally handed to HookHelper.

    Returns:
        ConcreteProgram: Holds the static inputs, the outputs, the collected
        parameters and buffers, the original dygraph function and both
        programs.
    """
    # Verify the instance is initialized in imperative mode.
    _verify_init_in_dynamic_mode(class_instance)

    # Convert the dygraph function into its static counterpart.
    dygraph_function = func_spec.dygraph_function
    static_fn = convert_to_static(dygraph_function)
    # Pre/post hooks are applied for the outermost layer.
    with_hook = kwargs.get("with_hook", False)
    hooks = HookHelper(dygraph_function, class_instance, with_hook)

    main_prog = framework.Program()
    startup_prog = framework.Program()
    # Note: The random seed should be synchronized into cached program
    # if set in `fluid.dygraph_guard` because some ops rely on it, such as
    # `fluid.layers.dropout`.
    main_prog.random_seed = framework.default_main_program().random_seed
    startup_prog.random_seed = (
        framework.default_startup_program().random_seed)

    from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
    with framework.program_guard(main_prog, startup_prog), \
            _switch_declarative_mode_guard_(is_declarative=True):
        # 1. Adds `fluid.data` layers for input if needed.
        static_inputs = func_spec.to_static_inputs_with_spec(
            input_spec, main_prog)
        _kwargs = func_spec.to_static_inputs_with_spec(
            input_kwargs_spec, main_prog)
        if class_instance:
            static_inputs = (class_instance, ) + tuple(static_inputs)

        # 2. Gets all ParamBases and buffered VarBases in the function.
        params_and_buffers = _extract_indeed_params_buffers(class_instance)

        # 3. Builds program only once and returns the output Variables.
        with param_guard(get_parameters(class_instance, False)):
            with param_guard(get_buffers(class_instance, False)):
                try:
                    # Hooks only matter for jit.save; they do nothing
                    # during the train and eval process.
                    inputs = hooks.apply_pre_hooks(static_inputs)
                    # Falsy keyword inputs mean a purely positional call.
                    outputs = static_fn(*inputs, **(_kwargs or {}))
                    outputs = hooks.apply_post_hooks(inputs, outputs)
                except BaseException as exc:
                    # NOTE: If the exception is raised in compile time, it
                    # should be attached to ERROR_DATA here.
                    error.attach_error_data(exc)
                    attached = getattr(exc, error.ERROR_DATA, None)
                    if attached:
                        attached.raise_new_exception()
                    raise

        if outputs is not None:
            # Only a tuple/list whose length is not 1 is kept as is; any
            # other non-None result is wrapped into a new list.
            keep_as_is = (isinstance(outputs, (tuple, list))
                          and len(outputs) != 1)
            if not keep_as_is:
                outputs = [outputs]

    main_prog = update_op_callstack_with_origin_info(main_prog)

    return ConcreteProgram(
        inputs=static_inputs,
        outputs=outputs,
        parameters=params_and_buffers,
        function=dygraph_function,
        main_program=main_prog,
        startup_program=startup_prog,
        **kwargs)