def from_func_spec(func_spec, input_spec, class_instance):
    """
    Builds the main_program with specialized inputs and returns outputs
    of program as fetch_list.

    Args:
        func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
        input_spec(list[InputSpec]): list of InputSpec describing the shape and
            dtype of each input, used to specialize the static program.
        class_instance: the Layer instance the function is bound to, or None
            for a plain function.
    """
    # Transforms dygraph function into static function and caches it.
    dygraph_function = func_spec.dygraph_function
    static_func = convert_to_static(dygraph_function)

    main_program, startup_program = framework.Program(), framework.Program()
    # Note: The random seed should be synchronized into cached program
    # if set in `fluid.dygraph_guard` because some ops rely on it, such as
    # `fluid.layers.dropout`.
    main_program.random_seed = framework.default_main_program().random_seed
    startup_program.random_seed = framework.default_startup_program(
    ).random_seed

    with framework.program_guard(main_program, startup_program):
        with _switch_declarative_mode_guard_(is_declarative=True):
            # 1. Adds `fluid.data` layers for input if needed
            inputs = func_spec.to_static_inputs_with_spec(input_spec,
                                                          main_program)
            if class_instance:
                inputs = tuple([class_instance] + list(inputs))

            # 2. Gets all ParamBases and buffered VarBases in the function
            all_parameters_and_buffers = list(
                get_parameters(class_instance).values()) + list(
                    get_buffers(class_instance).values())

            # 3. Builds program only once and returns the output Variables.
            with param_guard(get_parameters(
                    class_instance, False)), param_guard(
                        get_buffers(class_instance, False)):
                try:
                    outputs = static_func(*inputs)
                except BaseException as e:
                    # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
                    attach_error_data(e)
                    raise

            if not isinstance(outputs, (tuple, list)) and outputs is not None:
                outputs = [outputs]

    main_program = update_op_callstack_with_origin_info(main_program)

    return ConcreteProgram(
        inputs=inputs,
        outputs=outputs,
        parameters=all_parameters_and_buffers,
        function=dygraph_function,
        main_program=main_program,
        startup_program=startup_program)
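# Usage sketch (assumption: this builder is the internal path behind the public
# dygraph-to-static API). Users do not call `from_func_spec` directly; the same
# behaviour is reached through `paddle.jit.to_static` together with
# `paddle.static.InputSpec`, which specializes the inputs and builds the cached
# static program on the first call:
import paddle
from paddle.static import InputSpec


class SimpleNet(paddle.nn.Layer):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(10, 3)

    @paddle.jit.to_static(
        input_spec=[InputSpec(shape=[None, 10], dtype='float32')])
    def forward(self, x):
        return self.linear(x)


net = SimpleNet()
# The first call traces `forward`, builds the specialized program and caches it.
out = net(paddle.randn([4, 10]))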
def from_func_spec(func_spec):
    """
    Builds the main_program with specialized inputs and returns outputs
    of program as fetch_list.
    """
    # Transforms dygraph function into static function and caches it.
    dygraph_function = func_spec.dyfunc
    static_func = convert_function_with_cache(dygraph_function)

    main_program, startup_program = framework.Program(), framework.Program()
    # Note: The random seed should be synchronized into cached program
    # if set in `fluid.dygraph_guard` because some ops rely on it, such as
    # `fluid.layers.dropout`.
    main_program.random_seed = framework.default_main_program().random_seed
    startup_program.random_seed = framework.default_startup_program(
    ).random_seed

    with framework.program_guard(main_program, startup_program):
        # 1. Adds `fluid.data` layers for input if needed
        inputs = func_spec.to_static_inputs(main_program)

        # 2. Gets all ParamBases in the function
        all_parameters = list(func_spec.parameters().values())

        # 3. Builds program only once and returns the output Variables.
        with param_guard(func_spec.parameters(False)):
            outputs = static_func(*inputs)
        if not isinstance(outputs, (tuple, list)):
            outputs = [outputs] if outputs else []

    return ConcreteProgram(inputs=inputs,
                           outputs=outputs,
                           parameters=all_parameters,
                           func=dygraph_function,
                           main_program=main_program,
                           startup_program=startup_program)
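# Sketch of why the seed is copied into the cached program (assumption: the
# modern `paddle.seed` plays the role of the old `fluid.dygraph_guard` seed
# mentioned in the note above). Ops such as dropout consume the program-level
# seed, so the cached static program must inherit it to stay reproducible:
import paddle

paddle.seed(2023)


@paddle.jit.to_static
def noisy_forward(x):
    return paddle.nn.functional.dropout(x, p=0.5)


x = paddle.ones([2, 4])
y = noisy_forward(x)  # dropout reads the seed carried by the cached program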
def _lower(block, reverse):
    # Some functions which are only used in _lower.
    def bind(args, to_bind, value_table):
        for i in range(len(args)):
            if isinstance(args[i], list):
                bind(args[i], to_bind, value_table)
            elif args[i] is not None and args[i].name in to_bind:
                args[i] = value_table[to_bind[args[i].name]]

    def bind_name(names, to_bind):
        return_list = []
        for name in names:
            if isinstance(name, list):
                return_list.append(bind_name(name, to_bind))
            else:
                return_list.append(to_bind[name] if name in to_bind else name)
        return return_list

    def expand_nested_list(xs):
        return_list = []
        for x in xs:
            if isinstance(x, list):
                return_list = return_list + expand_nested_list(x)
            else:
                return_list.append(x)
        return return_list

    # Step1: Do some preparatory work for lower
    lower_fn = _prim2orig if reverse else _orig2prim
    lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim

    value_table = {}
    to_bind = {}
    to_bind_rev = {}
    for var in block.desc.all_vars():
        value_table[var.name()] = block.var(var.name())

    ops_to_remove = []
    vars_to_remove = set()

    # Step2: Process all ops in the target block
    for op_idx in range(len(block.ops)):
        op = block.ops[op_idx]
        ops_to_remove.append(op_idx)
        if lookup_fn(op.type) is not None:
            input_args = get_input_var_list(op)
            bind(input_args, to_bind, value_table)

            for orig_out, new_out in zip(
                    expand_nested_list(get_output_var_list(op)),
                    expand_nested_list(to_tensors(lower_fn(op, *input_args)))):
                assert not (orig_out is None) ^ (
                    new_out is None), "orig_out and new_out should match."
                vars_to_remove.add(new_out.name)
                value_table[new_out.name] = new_out
                to_bind[orig_out.name] = new_out.name
                to_bind_rev[new_out.name] = orig_out.name
        else:
            inputs = {}
            for i in range(len(op.input_names)):
                inputs[op.input_names[i]] = bind_name(
                    op.input(op.input_names[i]), to_bind)

            outputs = {}
            for i in range(len(op.output_names)):
                outputs[op.output_names[i]] = op.output(op.output_names[i])

            attrs = {}
            for name in sorted(op.attr_names):
                attrs[name] = op.attr(name)
            from paddle.fluid.dygraph.base import param_guard
            new_op_desc = block.desc.append_op()
            with param_guard(inputs), param_guard(outputs):
                op = Operator(block=block,
                              desc=new_op_desc,
                              type=op.type,
                              inputs=inputs,
                              outputs=outputs,
                              attrs=attrs)
            block.ops.append(op)

    # Step3: Do some post-processing work
    for op_idx in reversed(ops_to_remove):
        block.desc._remove_op(op_idx, op_idx + 1)
        del block.ops[op_idx]
    block._sync_with_cpp()

    for op_idx in range(len(block.ops)):
        op = block.ops[op_idx]
        for in_name in op.input_arg_names:
            if in_name in to_bind_rev:
                op._rename_input(in_name, to_bind_rev[in_name])

        for out_name in op.output_arg_names:
            if out_name in to_bind_rev:
                op._rename_output(out_name, to_bind_rev[out_name])

    for var_name in sorted(vars_to_remove):
        assert var_name in to_bind_rev, 'var_name "{}" is not in to_bind_rev.'.format(
            var_name)
        if var_name != to_bind_rev[var_name]:
            block.desc._remove_var(cpt.to_bytes(var_name))
            del block.vars[var_name]
    block._sync_with_cpp()
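# Standalone sketch of the nested-list helpers used inside `_lower` (assumption:
# the underscore-prefixed names and plain strings standing in for variable names
# are illustrative only). `bind_name` rewrites any name found in the `to_bind`
# mapping while preserving nesting, and `expand_nested_list` flattens the nested
# output lists before original and lowered outputs are zipped together:
def _bind_name(names, to_bind):
    return [
        _bind_name(n, to_bind) if isinstance(n, list) else to_bind.get(n, n)
        for n in names
    ]


def _expand_nested_list(xs):
    flat = []
    for x in xs:
        if isinstance(x, list):
            flat.extend(_expand_nested_list(x))
        else:
            flat.append(x)
    return flat


to_bind = {"x_orig": "x_new"}
print(_bind_name([["x_orig", "y"], "z"], to_bind))  # [['x_new', 'y'], 'z']
print(_expand_nested_list([["a", ["b"]], "c"]))     # ['a', 'b', 'c']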
def from_func_spec(func_spec, input_spec, input_kwargs_spec, class_instance,
                   **kwargs):
    """
    Builds the main_program with specialized inputs and returns outputs
    of program as fetch_list.

    Args:
        func_spec(FunctionSpec): A FunctionSpec instance for decorated function.
        input_spec(list[InputSpec]): list of InputSpec describing the shape and
            dtype of the positional inputs.
        input_kwargs_spec(dict[str, InputSpec]): InputSpec for keyword inputs.
        class_instance: the Layer instance the function is bound to, or None.
    """
    # Verify that the instance was initialized in imperative mode.
    _verify_init_in_dynamic_mode(class_instance)

    # Transforms dygraph function into static function and caches it.
    dygraph_function = func_spec.dygraph_function
    static_func = convert_to_static(dygraph_function)
    # Apply pre/post hooks for the outermost layer.
    hook_helper = HookHelper(dygraph_function, class_instance,
                             kwargs.get("with_hook", False))

    main_program, startup_program = framework.Program(), framework.Program()
    # Note: The random seed should be synchronized into cached program
    # if set in `fluid.dygraph_guard` because some ops rely on it, such as
    # `fluid.layers.dropout`.
    main_program.random_seed = framework.default_main_program().random_seed
    startup_program.random_seed = framework.default_startup_program(
    ).random_seed

    from paddle.fluid.dygraph.base import _switch_declarative_mode_guard_
    with framework.program_guard(main_program, startup_program):
        with _switch_declarative_mode_guard_(is_declarative=True):
            # 1. Adds `fluid.data` layers for input if needed
            static_inputs = func_spec.to_static_inputs_with_spec(
                input_spec, main_program)
            _kwargs = func_spec.to_static_inputs_with_spec(
                input_kwargs_spec, main_program)
            if class_instance:
                static_inputs = tuple([class_instance] + list(static_inputs))

            # 2. Gets all ParamBases and buffered VarBases in the function
            all_parameters_and_buffers = _extract_indeed_params_buffers(
                class_instance)

            # 3. Builds program only once and returns the output Variables.
            with param_guard(get_parameters(
                    class_instance, False)), param_guard(
                        get_buffers(class_instance, False)):
                try:
                    # Hooks are only applied for jit.save; this does nothing
                    # during the train and eval process.
                    inputs = hook_helper.apply_pre_hooks(static_inputs)
                    if _kwargs:
                        outputs = static_func(*inputs, **_kwargs)
                    else:
                        outputs = static_func(*inputs)
                    outputs = hook_helper.apply_post_hooks(inputs, outputs)
                except BaseException as e:
                    # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
                    error.attach_error_data(e)
                    error_data = getattr(e, error.ERROR_DATA, None)
                    if error_data:
                        error_data.raise_new_exception()
                    raise

            if outputs is not None:
                need_wrap_into_list = not isinstance(
                    outputs, (tuple, list)) or len(outputs) == 1
                if need_wrap_into_list:
                    outputs = [outputs]

    main_program = update_op_callstack_with_origin_info(main_program)

    return ConcreteProgram(inputs=static_inputs,
                           outputs=outputs,
                           parameters=all_parameters_and_buffers,
                           function=dygraph_function,
                           main_program=main_program,
                           startup_program=startup_program,
                           **kwargs)
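# Usage sketch for the hook support added in this version (assumption: the
# layer, hook, and the `with_hook=True` save config below are illustrative; the
# config is what is expected to reach `kwargs.get("with_hook", False)` above).
# Forward pre/post hooks registered on the outermost Layer are replayed while
# the static program is built, so the exported program matches dygraph behaviour:
import paddle
from paddle.static import InputSpec


class ScaledNet(paddle.nn.Layer):

    def __init__(self):
        super(ScaledNet, self).__init__()
        self.linear = paddle.nn.Linear(10, 3)

    def forward(self, x):
        return self.linear(x)


def double_input(layer, inputs):
    # Pre-hook: scale the positional inputs before `forward` runs.
    return tuple(x * 2 for x in inputs)


net = ScaledNet()
net.register_forward_pre_hook(double_input)
paddle.jit.save(net,
                path='./scaled_net',
                input_spec=[InputSpec(shape=[None, 10], dtype='float32')],
                with_hook=True)  # assumption: forwarded to kwargs above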