コード例 #1
0
def main(args):
    """Convert a frozen GraphDef file into a TensorFlow SavedModel.

    Reads the serialized graph at ``args.frozen_model_path``, imports it into a
    fresh graph (with no name prefix), and exports a SavedModel to
    ``args.output_model_dir``. The export has the ``SERVING`` tag and a single
    ``serving_default`` prediction signature:

    * inputs:  ``image_batch`` -> ``image_batch:0``,
               ``phase_train`` -> ``phase_train:0``
    * outputs: ``embeddings``  -> ``embeddings:0``

    An existing directory at ``args.output_model_dir`` is deleted first,
    presumably so the builder can create the export directory from scratch.

    :param args: object exposing ``frozen_model_path`` and ``output_model_dir``
                 (e.g. an argparse namespace).
    """
    with GFile(args.frozen_model_path, "rb") as f:
        graph_def = GraphDef()
        graph_def.ParseFromString(f.read())

    # isdir() already implies the path exists.
    if os.path.isdir(args.output_model_dir):
        shutil.rmtree(args.output_model_dir)

    # Import into a dedicated graph and bind the session to *that* graph.
    # A bare Session() created beforehand would be attached to whatever graph
    # was the default at construction time, not the one holding the model.
    with Graph().as_default() as graph:
        import_graph_def(graph_def, name='')
        with Session(graph=graph) as sess:
            signature = predict_signature_def(
                inputs={
                    'image_batch': graph.get_tensor_by_name('image_batch:0'),
                    'phase_train': graph.get_tensor_by_name('phase_train:0')
                },
                outputs={
                    'embeddings': graph.get_tensor_by_name('embeddings:0')
                })

            builder = saved_model_builder.SavedModelBuilder(
                args.output_model_dir)
            builder.add_meta_graph_and_variables(
                sess=sess,
                tags=[tag_constants.SERVING],
                signature_def_map={'serving_default': signature})
            builder.save()
コード例 #2
0
def get_model(framework, model_variant):
    """
    Load the desired EfficientPose model variant using the requested deep learning framework.
    
    Args:
        framework: string
            Deep learning framework to use (Keras, TensorFlow, TensorFlow Lite or PyTorch).
            Accepted values (case-insensitive): 'keras'/'k', 'tensorflow'/'tf',
            'tensorflowlite'/'tflite', 'pytorch'/'torch'.
        model_variant: string
            EfficientPose model to utilize (RT, I, II, III, IV, RT_Lite, I_Lite or II_Lite),
            case-insensitive.
            
    Returns:
        Initialized EfficientPose model and corresponding resolution.
        For PyTorch, (False, False) is returned if the model definition file for the
        requested variant cannot be read.

    Raises:
        ValueError: if `framework` is not one of the accepted values.
        KeyError: if `model_variant` has no known input resolution.
    """
    framework = framework.lower()
    # Model files on disk are named EfficientPose<VARIANT IN UPPER CASE>.<ext>.
    variant_name = model_variant.upper()
    
    # Keras
    if framework in ['keras', 'k']:
        from tensorflow.keras.backend import set_learning_phase
        from tensorflow.keras.models import load_model
        set_learning_phase(0)
        model = load_model(
            join('models', 'keras', 'EfficientPose{0}.h5'.format(variant_name)),
            custom_objects={'BilinearWeights': helpers.keras_BilinearWeights,
                            'Swish': helpers.Swish(helpers.eswish),
                            'eswish': helpers.eswish,
                            'swish1': helpers.swish1})
    
    # TensorFlow
    elif framework in ['tensorflow', 'tf']:
        from tensorflow.python.platform.gfile import FastGFile
        from tensorflow.compat.v1 import GraphDef
        from tensorflow.compat.v1.keras.backend import get_session
        from tensorflow import import_graph_def
        # `with` guarantees the file is closed even if parsing fails.
        with FastGFile(join('models', 'tensorflow', 'EfficientPose{0}.pb'.format(variant_name)), 'rb') as f:
            graph_def = GraphDef()
            graph_def.ParseFromString(f.read())
        model = get_session()
        # as_default() only takes effect as a context manager; without `with`
        # the graph would be imported into whatever graph is currently default.
        with model.graph.as_default():
            import_graph_def(graph_def)
    
    # TensorFlow Lite
    elif framework in ['tensorflowlite', 'tflite']:
        from tensorflow import lite
        model = lite.Interpreter(model_path=join('models', 'tflite', 'EfficientPose{0}.tflite'.format(variant_name)))
        model.allocate_tensors()
    
    # PyTorch
    elif framework in ['pytorch', 'torch']:
        import sys
        from importlib.util import module_from_spec, spec_from_file_location
        from torch import load, backends
        # Load the architecture source as module `MainModel` (what the removed
        # `imp.load_source` did); presumably the pickled model below refers to
        # its classes through that module name, so it must be in sys.modules.
        spec = spec_from_file_location('MainModel', join('models', 'pytorch', 'EfficientPose{0}.py'.format(variant_name)))
        main_model = module_from_spec(spec)
        sys.modules['MainModel'] = main_model
        try:
            spec.loader.exec_module(main_model)
        except OSError:
            # The definition file does not exist (e.g. Lite variants have no PyTorch port).
            del sys.modules['MainModel']
            print('\n##########################################################################################################')
            print('Desired model "EfficientPose{0}Lite" not available in PyTorch. Please select among "RT", "I", "II", "III" or "IV".'.format(model_variant.lower().split('lite')[0].upper()))
            print('##########################################################################################################\n')
            return False, False
        model = load(join('models', 'pytorch', 'EfficientPose{0}'.format(variant_name)))
        model.eval()
        backends.quantized.engine = 'qnnpack'

    else:
        raise ValueError(
            'Unsupported framework "{0}". Choose among keras/k, tensorflow/tf, '
            'tensorflowlite/tflite or pytorch/torch.'.format(framework))
            
    return model, {'rt': 224, 'i': 256, 'ii': 368, 'iii': 480, 'iv': 600, 'rt_lite': 224, 'i_lite': 256, 'ii_lite': 368}[model_variant.lower()]
コード例 #3
0
def load_inference_graph():
    """Load the frozen graph stored at ``PATH_TO_CKPT`` and open a session on it.

    The serialized ``GraphDef`` is imported (with no name prefix) into a newly
    created graph. A ``Session`` bound to that graph is then built using the
    module-level ``config``.

    Returns:
        ``(graph, session)``: the populated graph and its session.
    """
    print("> ====== Loading frozen graph into memory")
    graph = tf.Graph()
    with graph.as_default():
        with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as ckpt_file:
            raw_bytes = ckpt_file.read()
        frozen_def = GraphDef()
        frozen_def.ParseFromString(raw_bytes)
        tf.import_graph_def(frozen_def, name='')
        session = Session(graph=graph, config=config)
    print(">  ====== Inference graph loaded.")
    return graph, session
コード例 #4
0
def main():
    """Convert a memory-mapped model package back into a plain serialized GraphDef.

    Steps:
    1. Read the bytes of ``args.input``.
    2. Unpack them with ``load_memmapped_fs``.
    3. Parse the entry stored under the ``'memmapped_package://.'`` key as a
       ``GraphDef``.
    4. Hand the graph and the unpacked file system to ``undo_mmap``.
    5. Write the re-serialized graph to ``args.output``.

    NOTE(review): ``undo_mmap`` presumably replaces memmapped tensor references
    in the graph with their inline contents; confirm against its definition.
    """
    args = parse_args()
    package_fs = load_memmapped_fs(args.input.read_bytes())

    root_key = 'memmapped_package://.'
    graph_def = GraphDef()
    graph_def.ParseFromString(package_fs[root_key])

    undo_mmap(graph_def, package_fs)

    serialized_graph = graph_def.SerializeToString()
    args.output.write_bytes(serialized_graph)
コード例 #5
0
ファイル: root_model.py プロジェクト: johSchm/magpie
 def load(self, path, protocol_buffer=False):
     """ Loads a previously stored learn (model).
     :param path: directory (protocol buffer mode) or model path (Keras mode)
     :param protocol_buffer: if True, read the serialized GraphDef stored in
                             the file ``<path>/learn``; otherwise load a Keras
                             model from ``path``
     :return protocol_buffer = False: the loaded Keras model (also stored in
                                      ``self._model``)
             protocol_buffer = True:  the graph into which ``<path>/learn``
                                      was imported
     """
     if protocol_buffer:
         with gfile.FastGFile(os.path.join(path, 'learn'), 'rb') as model_file:
             graph_def = GraphDef()
             graph_def.ParseFromString(model_file.read())
         with Session() as sess:
             # as_default() only takes effect as a context manager; a bare call
             # does nothing.
             with sess.graph.as_default():
                 # import_graph_def returns only the requested return_elements
                 # (none here), so return the graph itself as documented.
                 tf.import_graph_def(graph_def)
             return sess.graph
     else:
         self._model = tf.keras.models.load_model(path)
         return self._model
コード例 #6
0
    def __init__(self,
                 checkpoint_filename,
                 input_name="images",
                 output_name="features"):
        """Import a frozen graph and look up its input and output tensors.

        Args:
            checkpoint_filename: path of the serialized ``GraphDef`` to load.
            input_name: name of the input tensor, without the ``net/`` prefix
                or the ``:0`` suffix.
            output_name: name of the output tensor, same convention.

        ``self.session`` is a plain ``Session()`` created before the import.
        The graph is imported into the current default graph under the
        ``net`` prefix. The output tensor must be rank 2 and the input tensor
        rank 4; this is checked with ``assert``, so the check is skipped
        under ``python -O``.
        """
        self.session = Session()
        with GFile(checkpoint_filename, "rb") as handle:
            serialized_graph = handle.read()
        graph_def = GraphDef()
        graph_def.ParseFromString(serialized_graph)
        import_graph_def(graph_def, name="net")

        default_graph = get_default_graph()
        tensor_template = "net/%s:0"
        self.input_var = default_graph.get_tensor_by_name(
            tensor_template % input_name)
        self.output_var = default_graph.get_tensor_by_name(
            tensor_template % output_name)

        output_shape = self.output_var.get_shape()
        input_shape = self.input_var.get_shape()
        assert len(output_shape) == 2
        assert len(input_shape) == 4
        # Size of the last output dimension (the feature vector length).
        self.feature_dim = output_shape.as_list()[-1]
        # Input shape without its leading dimension (presumably the batch axis).
        self.image_shape = input_shape.as_list()[1:]
コード例 #7
0
def load(file):
    """Load a TensorFlow graph from ``file`` into a new ``tf.Graph``.

    The format is chosen by file extension:

    * ``.pbtxt``: text-format GraphDef, parsed with ``text_format.Parse``;
    * ``.pb``: binary serialized GraphDef;
    * ``.meta``: MetaGraph, imported with ``tf.train.import_meta_graph``.

    GraphDef contents are imported with ``name=""``, so no name prefix is
    added to the node names.

    :param file: path of the model file.
    :return: the populated ``tf.Graph``, or ``None`` if the extension is not
             one of those listed above.
    """
    if file.endswith(".meta"):
        graph = tf.Graph()
        with graph.as_default():
            tf.train.import_meta_graph(file)
        return graph

    # `with` closes the handle; the previous bare open() calls leaked it.
    if file.endswith(".pbtxt"):
        with open(file, "r", encoding="utf-8") as f:
            graph_def = text_format.Parse(f.read(), GraphDef())
    elif file.endswith(".pb"):
        with open(file, "rb") as f:
            graph_def = GraphDef()
            graph_def.ParseFromString(f.read())
    else:
        # Unrecognised extension: keep the historical implicit ``None`` result.
        return None

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph