Пример #1
0
# Script fragment: build (or import) a model graph, restore its parameters from
# `args.model`, then dump them to `args.output`.
# At least one graph source is required: a config script or a metagraph.
assert args.config or args.meta, "Either config or metagraph must be present!"

# Everything below is created inside a fresh graph that is made the default.
with tf.Graph().as_default() as G:
    if args.config:
        # Load the file at `args.config` as a module named 'config_script' and
        # take its `Model` attribute, then instantiate it and build its graph
        # from the model's own input variables.
        MODEL = imp.load_source('config_script', args.config).Model
        M = MODEL()
        M.build_graph(M.get_input_vars())
    else:
        # No config script: construct the model from `args.meta` instead
        # (presumably a saved metagraph file -- confirm in ModelFromMetaGraph).
        M = ModelFromMetaGraph(args.meta)

    # Pick the restore strategy from the model file's extension.
    if args.model.endswith('.npy'):
        # `.item()` unwraps the single object stored in the .npy file
        # (presumably a dict of parameter values -- confirm against the writer).
        init = sessinit.ParamRestore(np.load(args.model).item())
    else:
        init = sessinit.SaverRestore(args.model)
    # Soft placement lets TensorFlow fall back to another device when the
    # requested one is unavailable.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    init.init(sess)

    # Dump the restored parameters, with `sess` as the default session.
    with sess.as_default():
        # NOTE(review): this checks 'npy' without the leading dot, unlike the
        # '.npy' check on `args.model` above.
        if args.output.endswith('npy'):
            varmanip.dump_session_params(args.output)
        else:
            # Gather the trainable variables plus whatever is registered in the
            # EXTRA_SAVE_VARS_KEY collection.
            var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            var.extend(tf.get_collection(EXTRA_SAVE_VARS_KEY))
            # Map each variable's save name to the variable itself.
            var_dict = {}
            for v in var:
                name = varmanip.get_savename_from_varname(v.name)
                var_dict[name] = v
            logger.info("Variables to dump:")
Пример #2
0
    def export(self,
               checkpoint,
               export_path,
               version=1,
               tags=None,
               signature_name='prediction_pipeline'):
        """Use SavedModelBuilder to export a trained model without TensorPack dependency.

        Builds the model graph in inference mode, restores the variables from
        ``checkpoint`` and writes a SavedModel with a single prediction
        signature to ``export_path``.

        Remarks:
            This produces
                variables/       # output from the vanilla Saver
                    variables.data-?????-of-?????
                    variables.index
                saved_model.pb   # saved model in protocol buffer format

            Currently, we only support a single signature, which is the general PredictSignatureDef:
            https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md

            As a side effect this sets ``self.sess``, ``self.inputs`` and
            ``self.outputs``; the session is left open.

        Args:
            checkpoint (str): path to checkpoint file
            export_path (str): path for export directory
            version (int): currently not used by this method; kept so existing
                callers that pass it keep working
            tags (list): list of user specified tags; ``None`` (the default)
                means ``[tf.saved_model.tag_constants.SERVING]``
            signature_name (str): name of signature for prediction
        """
        if tags is None:
            # Build the default per call: a list literal in the signature would
            # be evaluated once at definition time and shared by every call.
            tags = [tf.saved_model.tag_constants.SERVING]

        logger.info('[export] build model for %s', checkpoint)
        with TowerContext('', is_training=False):
            self.model._build_graph(self.placehdrs)

            self.sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True))
            # Run the global initializer first and only then restore from the
            # checkpoint (presumably so variables absent from the checkpoint
            # still hold a defined value -- confirm against SaverRestore).
            init = sessinit.SaverRestore(checkpoint)
            self.sess.run(tf.global_variables_initializer())
            init.init(self.sess)

            graph = tf.get_default_graph()

            # Resolve each configured name to the first output (':0') of the
            # op with that name.
            self.inputs = []
            for n in self.input_names:
                tensor = graph.get_tensor_by_name('%s:0' % n)
                logger.info('[export] add input-tensor "%s"', tensor.name)
                self.inputs.append(tensor)

            self.outputs = []
            for n in self.output_names:
                tensor = graph.get_tensor_by_name('%s:0' % n)
                logger.info('[export] add output-tensor "%s"', tensor.name)
                self.outputs.append(tensor)

            logger.info('[export] exporting trained model to %s', export_path)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            logger.info('[export] build signatures')
            # The signature maps the user-facing names to TensorInfo protos.
            inputs_signature = {}
            for n, v in zip(self.input_names, self.inputs):
                logger.info('[export] add input signature: %s', v)
                inputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)

            outputs_signature = {}
            for n, v in zip(self.output_names, self.outputs):
                logger.info('[export] add output signature: %s', v)
                outputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)

            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs_signature,
                outputs=outputs_signature,
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

            # NOTE: no legacy_init_op (e.g. a tables initializer) is registered
            # with the exported meta graph.
            builder.add_meta_graph_and_variables(
                self.sess,
                tags,
                signature_def_map={signature_name: prediction_signature})
            builder.save()