def _normalize_program(self, program, feed_vars, fetch_vars):
    """Return a pruned copy of ``program`` prepared for inference.

    Note:
        ``program`` itself is modified: the device attribute of every op in
        its global block is reset to an empty string before it is cloned.

    Args:
        program (Program): The program to normalize.
        feed_vars (Variable | list[Variable]): Variables fed at inference.
        fetch_vars (Variable | list[Variable]): Variables fetched at
            inference.

    Returns:
        Program: A clone of ``program`` with its previous feed/fetch ops
        removed, pruned against ``feed_vars``/``fetch_vars``, and with
        ``prepend_feed_ops``/``append_fetch_ops`` applied for the given
        variable names.

    Raises:
        TypeError: If ``program`` is not a ``Program``.
        TypeError: If ``feed_vars`` is not a Variable or a list of Variable.
        TypeError: If ``fetch_vars`` is not a Variable or a list of Variable.
    """
    if not isinstance(program, Program):
        raise TypeError(
            "program type must be `fluid.Program`, but received `%s`" %
            type(program))
    if not isinstance(feed_vars, list):
        feed_vars = [feed_vars]
    if not all(isinstance(v, Variable) for v in feed_vars):
        raise TypeError(
            "feed_vars type must be a Variable or a list of Variable.")
    if not isinstance(fetch_vars, list):
        fetch_vars = [fetch_vars]
    if not all(isinstance(v, Variable) for v in fetch_vars):
        raise TypeError(
            "fetch_vars type must be a Variable or a list of Variable.")

    # Clear the device attribute of *every* op. The previous version shared
    # one `break` with the auc check below, so ops located after the first
    # auc op silently kept their device attribute.
    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    auc_warned = False
    for op in program.global_block().ops:
        op._set_attr(device_attr_name, "")
        # Remind users (once) to set auc states to 0 if an auc op is found.
        if op.type == 'auc' and not auc_warned:
            warnings.warn("Be sure that you have set auc states to 0 "
                          "before saving inference model.")
            auc_warned = True

    # Work on a clone from here on so pruning does not touch `program`.
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_op_idx = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            remove_op_idx.append(i)
    # Remove from the back so the remaining indices stay valid.
    for idx in reversed(remove_op_idx):
        global_block._remove_op(idx)
    copy_program.desc.flush()

    feed_var_names = [var.name for var in feed_vars]
    copy_program = copy_program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    copy_program = copy_program._inference_optimize(prune_read_op=True)

    fetch_var_names = [var.name for var in fetch_vars]
    prepend_feed_ops(copy_program, feed_var_names)
    append_fetch_ops(copy_program, fetch_var_names)
    copy_program.desc._set_version()
    return copy_program
def _normalize_program(program, feed_vars, fetch_vars):
    """Optimize ``program`` according to ``feed_vars`` and ``fetch_vars``.

    Note:
        ``program`` itself is modified: the device attribute of every op in
        its global block is reset to an empty string, and scale ops named
        ``save_infer_model/scale_<i>`` are created under
        ``program_guard(program)`` for the non-bool fetch variables.

    Args:
        program: The program to normalize.
        feed_vars: List of variables fed at inference.
        fetch_vars: List of variables fetched at inference.

    Returns:
        A clone of ``program`` with its previous feed/fetch ops removed,
        pruned against the feed names and the (scaled) fetch targets, and
        with ``prepend_feed_ops``/``append_fetch_ops`` applied.
    """
    # Clear the device attribute of *every* op. The previous version shared
    # one `break` with the auc check below, so ops located after the first
    # auc op silently kept their device attribute.
    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    auc_warned = False
    for op in program.global_block().ops:
        op._set_attr(device_attr_name, "")
        # Remind users (once) to set auc states to 0 if an auc op is found.
        if op.type == 'auc' and not auc_warned:
            warnings.warn("Be sure that you have set auc states to 0 "
                          "before saving inference model.")
            auc_warned = True

    # fix the bug that the activation op's output as target will be pruned.
    # will affect the inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with program_guard(program):
        uniq_fetch_vars = []
        for i, var in enumerate(fetch_vars):
            # Bool variables are passed through untouched: they are not
            # routed through the scale op (same rule as normalize_program).
            if var.dtype != core.VarDesc.VarType.BOOL:
                var = layers.scale(
                    var, 1., name="save_infer_model/scale_{}".format(i))
            uniq_fetch_vars.append(var)
        fetch_vars = uniq_fetch_vars

    # Work on a clone from here on so pruning does not touch `program`.
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_op_idx = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            remove_op_idx.append(i)
    # Remove from the back so the remaining indices stay valid.
    for idx in reversed(remove_op_idx):
        global_block._remove_op(idx)
    copy_program.desc.flush()

    feed_var_names = [var.name for var in feed_vars]
    copy_program = copy_program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    copy_program = copy_program._inference_optimize(prune_read_op=True)

    fetch_var_names = [var.name for var in fetch_vars]
    prepend_feed_ops(copy_program, feed_var_names)
    append_fetch_ops(copy_program, fetch_var_names)
    copy_program.desc._set_version()
    return copy_program
def normalize_program(program, feed_vars, fetch_vars):
    """
    :api_attr: Static Graph

    Normalize/Optimize a program according to feed_vars and fetch_vars.

    Note:
        ``program`` itself is modified: the device attribute of every op in
        its global block is reset to an empty string, and scale ops named
        ``save_infer_model/scale_<i>`` are created under
        ``program_guard(program)`` for the non-bool fetch variables. The
        returned program is a pruned clone taken after these changes.

    Args:
        program(Program): Specify a program you want to optimize.
        feed_vars(Variable | list[Variable]): Variables needed by inference.
        fetch_vars(Variable | list[Variable]): Variables returned by inference.

    Returns:
        Program: Normalized/Optimized program.

    Raises:
        TypeError: If `program` is not a Program, an exception is thrown.
        TypeError: If `feed_vars` is not a Variable or a list of Variable, an exception is thrown.
        TypeError: If `fetch_vars` is not a Variable or a list of Variable, an exception is thrown.

    Examples:
        .. code-block:: python

            import paddle

            paddle.enable_static()

            path_prefix = "./infer_model"

            # User defined network, here a softmax regression example
            image = paddle.static.data(name='img', shape=[None, 28, 28], dtype='float32')
            label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
            predict = paddle.static.nn.fc(image, 10, activation='softmax')

            loss = paddle.nn.functional.cross_entropy(predict, label)

            exe = paddle.static.Executor(paddle.CPUPlace())
            exe.run(paddle.static.default_startup_program())

            # normalize main program.
            program = paddle.static.default_main_program()
            normalized_program = paddle.static.normalize_program(program, [image], [predict])

    """
    if not isinstance(program, Program):
        raise TypeError(
            "program type must be `fluid.Program`, but received `%s`" %
            type(program))
    if not isinstance(feed_vars, list):
        feed_vars = [feed_vars]
    if not all(isinstance(v, Variable) for v in feed_vars):
        raise TypeError(
            "feed_vars type must be a Variable or a list of Variable.")
    if not isinstance(fetch_vars, list):
        fetch_vars = [fetch_vars]
    if not all(isinstance(v, Variable) for v in fetch_vars):
        raise TypeError(
            "fetch_vars type must be a Variable or a list of Variable.")

    # Clear the device attribute of *every* op. The previous version shared
    # one `break` with the auc check below, so ops located after the first
    # auc op silently kept their device attribute.
    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    auc_warned = False
    for op in program.global_block().ops:
        op._set_attr(device_attr_name, "")
        # Remind users (once) to set auc states to 0 if an auc op is found.
        if op.type == 'auc' and not auc_warned:
            warnings.warn("Be sure that you have set auc states to 0 "
                          "before saving inference model.")
            auc_warned = True

    # fix the bug that the activation op's output as target will be pruned.
    # will affect the inference performance.
    # TODO(Superjomn) add an IR pass to remove 1-scale op.
    with program_guard(program):
        uniq_fetch_vars = []
        for i, var in enumerate(fetch_vars):
            # Bool variables are passed through without a scale op.
            if var.dtype != paddle.bool:
                var = layers.scale(
                    var, 1., name="save_infer_model/scale_{}".format(i))
            uniq_fetch_vars.append(var)
        fetch_vars = uniq_fetch_vars

    # Work on a clone from here on so pruning does not touch `program`.
    copy_program = program.clone()
    global_block = copy_program.global_block()
    remove_op_idx = []
    for i, op in enumerate(global_block.ops):
        op.desc.set_is_target(False)
        if op.type == "feed" or op.type == "fetch":
            remove_op_idx.append(i)
    # Remove from the back so the remaining indices stay valid.
    for idx in reversed(remove_op_idx):
        global_block._remove_op(idx)
    copy_program.desc.flush()

    feed_var_names = [var.name for var in feed_vars]
    copy_program = copy_program._prune_with_input(
        feeded_var_names=feed_var_names, targets=fetch_vars)
    copy_program = copy_program._inference_optimize(prune_read_op=True)

    fetch_var_names = [var.name for var in fetch_vars]
    prepend_feed_ops(copy_program, feed_var_names)
    append_fetch_ops(copy_program, fetch_var_names)
    copy_program.desc._set_version()
    return copy_program