def _generate_test_data(self, run_name, experiment_name):
    """Generates the test data directory.

    The test data has a single run of the given name, containing:
      - a graph definition and metagraph definition

    Arguments:
      run_name: The directory under self.logdir into which to write
          events.
      experiment_name: The experiment name to use when writing the run
          to the database.
    """
    run_path = os.path.join(self.logdir, run_name)
    writer = tf.summary.FileWriter(run_path)

    # Add a simple graph event.
    graph_def = tf.GraphDef()
    node1 = graph_def.node.add()
    node1.name = 'a'
    node2 = graph_def.node.add()
    node2.name = 'b'
    node2.attr['very_large_attr'].s = b'a' * 2048  # 2 KB attribute

    meta_graph_def = tf.MetaGraphDef(graph_def=graph_def)

    if self._only_use_meta_graph:
      writer.add_meta_graph(meta_graph_def)
    else:
      writer.add_graph(graph_def)

    writer.flush()
    writer.close()

    # Write data for the run to the database.
    # TODO(nickfelt): Figure out why resetting the graph is necessary.
    tf.reset_default_graph()
    db_writer = tf.contrib.summary.create_db_writer(
        db_uri=self.db_path,
        experiment_name=experiment_name,
        run_name=run_name,
        user_name='user')
    with db_writer.as_default(), tf.contrib.summary.always_record_summaries():
      tf.contrib.summary.scalar('mytag', 1)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.contrib.summary.summary_writer_initializer_op())
      sess.run(tf.contrib.summary.all_summary_ops())
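
To sanity-check what the example above writes, one can iterate over the emitted event files. A minimal sketch, assuming `run_path` points at the run directory written above:

import glob
import os

import tensorflow as tf

run_path = '/tmp/logdir/myrun'  # assumption: the directory written above
for events_file in glob.glob(os.path.join(run_path, 'events.out.tfevents.*')):
    for event in tf.train.summary_iterator(events_file):
        # graph_def / meta_graph_def are oneof fields on the Event proto.
        if event.HasField('graph_def'):
            print('graph event at step', event.step)
        if event.HasField('meta_graph_def'):
            print('metagraph event at step', event.step)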
Example #2

    def _load_graph_using_meta(self, model):
        tf.reset_default_graph()
        graph = tf.Graph()
        meta_graph_def = tf.MetaGraphDef()

        with open(model, "rb") as model_file:
            meta_graph_def.ParseFromString(model_file.read())

        with tf.Session() as sess:
            restorer = tf.train.import_meta_graph(meta_graph_def)
            restorer.restore(sess, re.sub(r'\.meta$', '', model))
            # Freeze the graph: replace variables with constants holding
            # the restored checkpoint values.
            graph_def = tf.graph_util.convert_variables_to_constants(
                sess, meta_graph_def.graph_def, self._config_outputs
            )

        with graph.as_default():
            tf.import_graph_def(graph_def, name='')
        return graph
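
Once the frozen graph is returned, inference no longer needs the checkpoint. A hypothetical call (the loader instance and tensor names are assumptions, not part of the example above):

import numpy as np
import tensorflow as tf

graph = loader._load_graph_using_meta('/tmp/model.ckpt.meta')  # hypothetical instance
with tf.Session(graph=graph) as sess:
    # 'input:0' and 'output:0' are placeholder tensor names.
    out = sess.run('output:0',
                   feed_dict={'input:0': np.zeros((1, 4), np.float32)})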
Example #3
def main():
    # Parse the checkpoint metagraph twice: once via the TF helper, once
    # by deserializing the proto directly.
    meta_graph.read_meta_graph_file("/tmp/tf_training/ckpt/rjmodel.ckpt.meta")
    g = tf.MetaGraphDef()
    with open("/tmp/tf_training/ckpt/rjmodel.ckpt.meta", "rb") as f:
        g.ParseFromString(f.read())
    #print("GraphDef from meta_graph_file:", g.graph_def)
    g = tf.GraphDef()
    with open("/tmp/stylize_quantized.pb", "rb") as f:
        g.ParseFromString(f.read())
    print("GraphDef: ", g)
    #[n for n in g.node if n.name.find("input") != -1] # same for output or any other node you want to make sure is ok

    saved_model = saved_model_pb2.SavedModel()
    with open("/tmp/SavedModel/saved_model.pb", "rb") as f:
        saved_model.ParseFromString(f.read())
    print("saved_model parsed", saved_model.saved_model_schema_version,
          len(saved_model.meta_graphs))
    print("GraphDef from SavedModel file:",
          saved_model.meta_graphs[0].graph_def)
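
The commented-out list comprehension above is the usual way to locate nodes by name; expanded into a working sketch over the parsed GraphDef `g` (the substrings are assumptions that depend on the model):

input_nodes = [n.name for n in g.node if 'input' in n.name]
output_nodes = [n.name for n in g.node if 'output' in n.name]
print('candidate inputs:', input_nodes)
print('candidate outputs:', output_nodes)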
Example #4
def predict_func(rows,
                 graph_json,
                 prediction,
                 graph_weights,
                 inp,
                 activation,
                 tf_input,
                 tf_dropout=None,
                 to_keep_dropout=False):
    rows = [r.asDict() for r in rows]
    if len(rows) > 0:
        graph = tf.MetaGraphDef()
        graph = json_format.Parse(graph_json, graph)
        loaded_weights = json.loads(graph_weights)
        loaded_weights = [np.asarray(x) for x in loaded_weights]

        A = [np.asarray(row[inp]) for row in rows]

        new_graph = tf.Graph()
        with tf.Session(graph=new_graph) as sess:
            tf.train.import_meta_graph(graph)
            sess.run(tf.global_variables_initializer())
            tensorflow_set_weights(loaded_weights)
            out_node = tf.get_default_graph().get_tensor_by_name(activation)
            dropout_v = 1.0 if tf_dropout is not None and to_keep_dropout else 0.0
            feed_dict = {
                tf_input: A
            } if tf_dropout is None else {
                tf_input: A,
                tf_dropout: dropout_v
            }

            pred = sess.run(out_node, feed_dict=feed_dict)
            for i in range(0, len(rows)):
                row = rows[i]
                try:
                    # Scalar predictions become plain floats; dense vectors
                    # are handled differently in Python 3.
                    internal = float(pred[i])
                    row[prediction] = internal
                except (TypeError, ValueError):
                    row[prediction] = Vectors.dense(pred[i])
        return [Row(**a) for a in rows]
    return []
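
predict_func is shaped for Spark's mapPartitions, scoring each partition in a single session. A hypothetical invocation (the DataFrame, column, and tensor names are all assumptions):

# Sketch: score a DataFrame partition by partition.
predictions = df.rdd.mapPartitions(
    lambda rows: predict_func(rows,
                              graph_json=serialized_meta_graph_json,
                              prediction='prediction',
                              graph_weights=serialized_weights_json,
                              inp='features',
                              activation='out/Sigmoid:0',  # placeholder tensor name
                              tf_input='x:0')              # placeholder input name
).toDF()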
Example #5
    def _tf_meta_graph_def_to_lgf_meta_graph_info(self, meta_graph_def):
        # Store the serialized partial meta graph def
        partial_meta_graph_def = tf.MetaGraphDef()
        partial_meta_graph_def.saver_def.CopyFrom(meta_graph_def.saver_def)
        for k, v in meta_graph_def.collection_def.items():
            partial_meta_graph_def.collection_def[k].CopyFrom(v)
        for k, v in meta_graph_def.signature_def.items():
            partial_meta_graph_def.signature_def[k].CopyFrom(v)

        partial_meta_graph_def.graph_def.CopyFrom(meta_graph_def.graph_def)
        del partial_meta_graph_def.graph_def.node[:]

        meta_graph_info = lgf_pb2.MetaGraphInfo()
        key = ImportTFSavedModelBase.PARTIAL_META_GRAPH_DEF
        meta_graph_info.original_graph_info[key].v = (
            partial_meta_graph_def.SerializeToString())

        # Get all strings from the partial meta graph def
        proto_strings = self.get_strings_from_proto(partial_meta_graph_def)

        # Need to manually deserialize the collection defs
        for k, collection_def in partial_meta_graph_def.collection_def.items():
            if collection_def.HasField("bytes_list"):
                proto_type = tf_ops.get_collection_proto_type(k)
                for serialized_proto in collection_def.bytes_list.value:
                    proto = proto_type()
                    proto.ParseFromString(serialized_proto)
                    proto_strings.extend(self.get_strings_from_proto(proto))

        # Get all the node names from proto_strings
        node_names = {n.name for n in self._graph_def.node}
        for string in proto_strings:
            if string != "":
                name = self.get_node_name_and_output_index(string)[0]
                if name in node_names:
                    self._required_nodes.add(name)

        # Add required nodes to meta_graph_info
        meta_graph_info.required_nodes.extend(self._required_nodes)

        return meta_graph_info
Example #6
    def test_add_new_inits_to_collection(self):
        meta_graph_def = tf.MetaGraphDef()

        orig_table_inits = ['t1', 't2']
        new_table_inits = ['t3', 't4']

        meta_graph_def.collection_def[
            tf.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
                orig_table_inits)
        updated_init_names = {
            tf.GraphKeys.TABLE_INITIALIZERS: orig_table_inits + new_table_inits
        }

        meta_graph_editor._add_new_inits_to_collection(meta_graph_def,
                                                       updated_init_names)

        self.assertEqual(
            meta_graph_def.collection_def[
                tf.GraphKeys.TABLE_INITIALIZERS].node_list.value,
            orig_table_inits + new_table_inits)
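
The helper under test is not shown on this page; a minimal sketch consistent with the assertion above (append any initializer names not already in the collection) could look like:

def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names):
    # Sketch only: merge the updated names into each collection's node_list.
    for key, names in updated_initializer_names.items():
        node_list = meta_graph_def.collection_def[key].node_list
        for name in names:
            if name not in node_list.value:
                node_list.value.append(name)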
Example #7
def load_metagraph(model, url_prefix, metagraph_cache):
    """Loads and caches a trained model metagraph."""
    filename = os.path.join(metagraph_cache, model + ".metagraph")
    try:
        with tf.io.gfile.GFile(filename, "rb") as f:
            string = f.read()
    except tf.errors.NotFoundError:
        url = url_prefix + "/" + model + ".metagraph"
        request = urllib.request.urlopen(url)
        try:
            string = request.read()
        finally:
            request.close()
        tf.io.gfile.makedirs(os.path.dirname(filename))
        with tf.io.gfile.GFile(filename, "wb") as f:
            f.write(string)
    metagraph = tf.MetaGraphDef()
    metagraph.ParseFromString(string)
    tf.train.import_meta_graph(metagraph)
    return metagraph.signature_def
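
A hypothetical call, with the model name, URL prefix, and cache directory all placeholders:

signature_def = load_metagraph(
    model='my_model',
    url_prefix='https://example.com/models',
    metagraph_cache='/tmp/metagraph_cache')
print(list(signature_def.keys()))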
Example #8
    def _ProcessEvent(self, event):
        """Called whenever an event is loaded."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField('file_version'):
            new_file_version = _ParseFileVersion(event.file_version)
            if self.file_version and self.file_version != new_file_version:
                ## This should not happen.
                tf.logging.warn(
                    ('Found new file_version for event.proto. This will '
                     'affect purging logic for TensorFlow restarts. '
                     'Old: {0} New: {1}').format(self.file_version,
                                                 new_file_version))
            self.file_version = new_file_version

        self._MaybePurgeOrphanedData(event)

        ## Process the event.
        # GraphDef and MetaGraphDef are handled in a special way:
        # If no graph_def Event is available, but a meta_graph_def is, and it
        # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
        # If a graph_def Event is available, always prefer it to the graph_def
        # inside the meta_graph_def.
        if event.HasField('graph_def'):
            if self._graph is not None:
                tf.logging.warn(
                    ('Found more than one graph event per run, or there was '
                     'a metagraph containing a graph_def, as well as one or '
                     'more graph events.  Overwriting the graph with the '
                     'newest event.'))
            self._graph = event.graph_def
            self._graph_from_metagraph = False
        elif event.HasField('meta_graph_def'):
            if self._meta_graph is not None:
                tf.logging.warn(
                    ('Found more than one metagraph event per run. '
                     'Overwriting the metagraph with the newest event.'))
            self._meta_graph = event.meta_graph_def
            if self._graph is None or self._graph_from_metagraph:
                # We may have a graph_def in the metagraph.  If so, and no
                # graph_def is directly available, use this one instead.
                meta_graph = tf.MetaGraphDef()
                meta_graph.ParseFromString(self._meta_graph)
                if meta_graph.graph_def:
                    if self._graph is not None:
                        tf.logging.warn((
                            'Found multiple metagraphs containing graph_defs, '
                            'but did not find any graph events.  Overwriting the '
                            'graph with the newest metagraph version.'))
                    self._graph_from_metagraph = True
                    self._graph = meta_graph.graph_def.SerializeToString()
        elif event.HasField('tagged_run_metadata'):
            tag = event.tagged_run_metadata.tag
            if tag in self._tagged_metadata:
                tf.logging.warn(
                    'Found more than one "run metadata" event with tag ' +
                    tag + '. Overwriting it with the newest event.')
            self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
        elif event.HasField('summary'):
            for value in event.summary.value:
                value = data_compat.migrate_value(value)

                if value.HasField('metadata'):
                    tag = value.tag
                    # We only store the first instance of the metadata. This check
                    # is important: the `FileWriter` does strip metadata from all
                    # values except the first one per each tag, but a new
                    # `FileWriter` is created every time a training job stops and
                    # restarts. Hence, we must also ignore non-initial metadata in
                    # this logic.
                    if tag not in self.summary_metadata:
                        self.summary_metadata[tag] = value.metadata
                        plugin_data = value.metadata.plugin_data
                        if plugin_data.plugin_name:
                            self._plugin_to_tag_to_content[
                                plugin_data.plugin_name][tag] = (
                                    plugin_data.content)
                        else:
                            tf.logging.warn((
                                'This summary with tag %r is oddly not associated with a '
                                'plugin.'), tag)

                for summary_type, summary_func in SUMMARY_TYPES.items():
                    if value.HasField(summary_type):
                        datum = getattr(value, summary_type)
                        tag = value.tag
                        if summary_type == 'tensor' and not tag:
                            # This tensor summary was created using the old method that used
                            # plugin assets. We must still continue to support it.
                            tag = value.node_name
                        getattr(self, summary_func)(tag, event.wall_time,
                                                    event.step, datum)
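
The central move above — recovering a GraphDef from a serialized metagraph — also works standalone. A minimal sketch:

import tensorflow as tf

def graph_def_bytes_from_meta(meta_graph_bytes):
    # Parse the serialized MetaGraphDef and re-serialize its embedded
    # graph_def, mirroring how _ProcessEvent stores self._graph.
    meta_graph = tf.MetaGraphDef()
    meta_graph.ParseFromString(meta_graph_bytes)
    return meta_graph.graph_def.SerializeToString()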
Example #9
  def _ProcessEvent(self, event):
    """Called whenever an event is loaded."""
    if self._first_event_timestamp is None:
      self._first_event_timestamp = event.wall_time

    if event.HasField('file_version'):
      new_file_version = _ParseFileVersion(event.file_version)
      if self.file_version and self.file_version != new_file_version:
        ## This should not happen.
        tf.logging.warn(('Found new file_version for event.proto. This will '
                         'affect purging logic for TensorFlow restarts. '
                         'Old: {0} New: {1}').format(self.file_version,
                                                     new_file_version))
      self.file_version = new_file_version

    self._MaybePurgeOrphanedData(event)

    ## Process the event.
    # GraphDef and MetaGraphDef are handled in a special way:
    # If no graph_def Event is available, but a meta_graph_def is, and it
    # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
    # If a graph_def Event is available, always prefer it to the graph_def
    # inside the meta_graph_def.
    if event.HasField('graph_def'):
      if self._graph is not None:
        tf.logging.warn(
            ('Found more than one graph event per run, or there was '
             'a metagraph containing a graph_def, as well as one or '
             'more graph events.  Overwriting the graph with the '
             'newest event.'))
      self._graph = event.graph_def
      self._graph_from_metagraph = False
    elif event.HasField('meta_graph_def'):
      if self._meta_graph is not None:
        tf.logging.warn(('Found more than one metagraph event per run. '
                         'Overwriting the metagraph with the newest event.'))
      self._meta_graph = event.meta_graph_def
      if self._graph is None or self._graph_from_metagraph:
        # We may have a graph_def in the metagraph.  If so, and no
        # graph_def is directly available, use this one instead.
        meta_graph = tf.MetaGraphDef()
        meta_graph.ParseFromString(self._meta_graph)
        if meta_graph.graph_def:
          if self._graph is not None:
            tf.logging.warn(
                ('Found multiple metagraphs containing graph_defs, '
                 'but did not find any graph events.  Overwriting the '
                 'graph with the newest metagraph version.'))
          self._graph_from_metagraph = True
          self._graph = meta_graph.graph_def.SerializeToString()
    elif event.HasField('tagged_run_metadata'):
      tag = event.tagged_run_metadata.tag
      if tag in self._tagged_metadata:
        tf.logging.warn('Found more than one "run metadata" event with tag ' +
                        tag + '. Overwriting it with the newest event.')
      self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
    elif event.HasField('summary'):
      for value in event.summary.value:
        if (value.HasField('tensor') and
            value.tag.startswith(HEALTH_PILL_EVENT_TAG_PREFIX)):
          self._ProcessHealthPillSummary(value, event)
        else:
          for summary_type, summary_func in SUMMARY_TYPES.items():
            if value.HasField(summary_type):
              datum = getattr(value, summary_type)
              tag = value.node_name if summary_type == 'tensor' else value.tag
              getattr(self, summary_func)(tag, event.wall_time, event.step,
                                          datum)
Example #10

    def __init__(self, meta_file):
        self._meta_def = tf.MetaGraphDef()
        with open(meta_file, 'rb') as fd:
            self._meta_def.ParseFromString(fd.read())
Example #11
    def start_service(self, tensorflowGraph, optimizer):
        """
        Asynchronous Flask service. It may seem confusing that the server starts
        here rather than in __init__: this method runs in a separate process, and
        when Python forks we want to fork from this thread, not the master thread.
        """
        app = Flask(__name__)
        self.app = app
        max_errors = self.iters
        lock = RWLock()

        server = tf.train.Server.create_local_server()
        graph = tf.MetaGraphDef()
        metagraph = json_format.Parse(tensorflowGraph, graph)
        ng = tf.Graph()
        with ng.as_default():
            tf.train.import_meta_graph(metagraph)
            loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
            grads = tf.gradients(loss_variable, tf.trainable_variables())
            grads = list(zip(grads, tf.trainable_variables()))
            train_op = optimizer.apply_gradients(grads)
            init = tf.global_variables_initializer()

        glob_session = tf.Session(server.target, graph=ng)
        with ng.as_default():
            with glob_session.as_default():
                glob_session.run(init)
                self.weights = tensorflow_get_weights()

        cont = itertools.count()
        lock_acquired = self.acquire_lock

        @app.route('/')
        def home():
            return 'Lifeomic'

        @app.route('/parameters', methods=['GET'])
        def get_parameters():
            if lock_acquired:
                lock.acquire_read()
            vs = pickle.dumps(self.weights)
            if lock_acquired:
                lock.release()
            return vs

        @app.route('/update', methods=['POST'])
        def update_parameters():
            with ng.as_default():
                gradients = pickle.loads(request.data)
                nu_feed = {}
                for x, grad_var in enumerate(grads):
                    nu_feed[grad_var[0]] = gradients[x]

                if lock_acquired:
                    lock.acquire_write()

                with glob_session.as_default():
                    try:
                        glob_session.run(train_op, feed_dict=nu_feed)
                        self.weights = tensorflow_get_weights()
                    except Exception:
                        error_cnt = next(cont)
                        if error_cnt >= max_errors:
                            raise Exception("Too many failures during training")
                    finally:
                        if lock_acquired:
                            lock.release()

            return 'completed'

        self.app.run(host='0.0.0.0', debug=True, threaded=True, use_reloader=False, port=5000)
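
The service exchanges pickled payloads over plain HTTP. A client sketch matching the two routes (host and port assume the app.run call above):

import pickle
import urllib.request

BASE = 'http://localhost:5000'

def get_server_weights():
    with urllib.request.urlopen(BASE + '/parameters') as resp:
        return pickle.loads(resp.read())

def put_gradients(gradients):
    # The gradient list must be ordered like `grads` on the server.
    req = urllib.request.Request(BASE + '/update',
                                 data=pickle.dumps(gradients),
                                 method='POST')
    with urllib.request.urlopen(req) as resp:
        return resp.read()  # b'completed' on success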
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--export",
                        default=None,
                        help="output_dir from --mode export run")
    parser.add_argument("--frozen", default=None, help="frozen graph pb file")
    parser.add_argument("--tflite", default=None, help="tflite file")
    parser.add_argument("--input", default='Untitled.png', help="input image")
    parser.add_argument("--output", default='out.png', help="output image")
    a = parser.parse_args()

    out_files = [a.output]
    im_files = [a.input]
    images_cv = [cv2.resize(cv2.imread(f), (256, 256)) for f in im_files]
    images = np.array(images_cv, dtype=np.float32)
    images = images / 255.0

    if a.tflite:
        interpreter = tf.lite.Interpreter(model_path=a.tflite)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)

        interpreter.set_tensor(input_details[0]['index'], images)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details[0]['index'])
        output = output[:, :, :, ::-1]
        output = output * 255

        print("Writing " + out_files[0])
        cv2.imwrite(out_files[0], output[0])

    if a.frozen:
        with tf.Session() as sess:
            with gfile.FastGFile(a.frozen, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                sess.graph.as_default()
                tf.import_graph_def(graph_def, name='')

            input = tf.get_default_graph().get_tensor_by_name("TFLiteInput:0")
            output = tf.get_default_graph().get_tensor_by_name(
                "TFLiteOutput:0")
            output = output[:, :, :, ::-1]
            output = output * 255

            print("Writing " + out_files[0])
            cv2.imwrite(out_files[0],
                        output.eval({'TFLiteInput:0': images})[0])

    if a.export:
        with tf.Session() as sess:
            with gfile.FastGFile(os.path.join(a.export, 'export.meta'),
                                 'rb') as f:
                meta_graph_def = tf.MetaGraphDef()
                meta_graph_def.ParseFromString(f.read())
                tf.train.import_meta_graph(meta_graph_def)
                checkpoint = tf.train.latest_checkpoint(a.export)
                restore_saver = tf.train.Saver()
                restore_saver.restore(sess, checkpoint)

            input = tf.get_default_graph().get_tensor_by_name("TFLiteInput:0")
            output = tf.get_default_graph().get_tensor_by_name(
                "TFLiteOutput:0")
            output = output[:, :, :, ::-1]
            output = output * 255

            print("Writing " + out_files[0])
            cv2.imwrite(out_files[0],
                        output.eval({'TFLiteInput:0': images})[0])
Example #13
    def get_meta_graph_copy(self, tags=None):
        """Returns a copy of a MetaGraph with the identical set of tags."""
        meta_graph = self.get_meta_graph(tags)
        copy = tf.MetaGraphDef()
        copy.CopyFrom(meta_graph)
        return copy
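
CopyFrom makes a deep copy, so mutating the returned proto leaves the cached metagraph untouched. A quick self-contained check:

import tensorflow as tf

original = tf.MetaGraphDef()
original.graph_def.node.add().name = 'a'
copy = tf.MetaGraphDef()
copy.CopyFrom(original)       # deep copy of every field
del copy.graph_def.node[:]    # mutating the copy...
assert len(original.graph_def.node) == 1  # ...leaves the original intact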
Example #14
    def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock):
        tag_name = 'tag'
        input_nodes = 'input_nodes'
        output_nodes = 'output_nodes'
        freeze_transform = 'freeze_graph'
        sparsify_transform = 'sparsify_gather'

        base_meta_graph_def = tf.MetaGraphDef()

        # Add a table initializer.
        table_init_name = 'table_init'
        node_def = tf.NodeDef(name=table_init_name, op='InitializeTableV2')
        base_meta_graph_def.graph_def.node.extend([node_def])

        # Add a group_deps node.
        group_deps_name = 'group_deps'
        node_def = tf.NodeDef(name=group_deps_name, op='NoOp')
        node_def.input.extend(['^table_init'])
        base_meta_graph_def.graph_def.node.extend([node_def])

        base_meta_graph_def.collection_def[
            tf.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
                [table_init_name])
        base_meta_graph_def.collection_def[
            tf.saved_model.constants.
            LEGACY_INIT_OP_KEY].node_list.value.extend([group_deps_name])

        # Expected metagraphdef.
        expected_meta_graph_def = tf.MetaGraphDef()
        expected_meta_graph_def.CopyFrom(base_meta_graph_def)
        expected_meta_graph_def.meta_info_def.tags.append(tag_name)

        transformed_graph_def = tf.GraphDef()
        transformed_graph_def.CopyFrom(expected_meta_graph_def.graph_def)
        freeze_mock.return_value = transformed_graph_def
        graph_transform_mock.return_value = transformed_graph_def

        # Add unsaved init node.
        unsaved_init_name = 'unsaved_node'
        node_def = tf.NodeDef(name=unsaved_init_name, op='NoOp')
        base_meta_graph_def.graph_def.node.extend([node_def])

        # Add a saver.
        base_meta_graph_def.saver_def.filename_tensor_name = 'node1'
        base_meta_graph_def.saver_def.save_tensor_name = 'node3'
        base_meta_graph_def.saver_def.restore_op_name = 'node6'

        transformed_meta_graph_def = meta_graph_editor.meta_graph_editor(
            base_meta_graph_def, [input_nodes], [output_nodes],
            [freeze_transform, sparsify_transform], [tag_name])

        self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
        freeze_mock.assert_called_once_with(base_meta_graph_def.graph_def,
                                            [output_nodes], [table_init_name],
                                            group_deps_name,
                                            base_meta_graph_def.saver_def,
                                            None)
        graph_transform_mock.assert_called_once_with(transformed_graph_def, [
            input_nodes
        ], [output_nodes, group_deps_name, table_init_name], [
            sparsify_transform + '(group_init_node="sparify_gather_init_op")'
        ])
Example #15
def handle_model(data, graph_json, tfInput, tfLabel=None,
                 master_url='localhost:5000', iters=1000,
                 mini_batch_size=-1, shuffle=True,
                 mini_stochastic_iters=-1, verbose=0, loss_callback=None):
    is_supervised = tfLabel is not None
    features, labels = handle_features(data, is_supervised)

    gd = tf.MetaGraphDef()
    gd = json_format.Parse(graph_json, gd)
    new_graph = tf.Graph()
    with tf.Session(graph=new_graph) as sess:
        tf.train.import_meta_graph(gd)
        loss_variable = tf.get_collection(tf.GraphKeys.LOSSES)[0]
        sess.run(tf.global_variables_initializer())
        trainable_variables = tf.trainable_variables()
        grads = tf.gradients(loss_variable, trainable_variables)
        grads = list(zip(grads, trainable_variables))
        partition_id = uuid.uuid4().hex
        for i in range(0, iters):
            weights = get_server_weights(master_url)
            tensorflow_set_weights(weights, vs=trainable_variables)
            if shuffle:
                features, labels = handle_shuffle(features, labels)

            if mini_stochastic_iters >= 1:
                for _ in range(0, mini_stochastic_iters):
                    gradients = []
                    feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels, mini_batch_size)
                    for x in range(len(grads)):
                        gradients.append(grads[x][0].eval(feed_dict=feed_dict))
                    try:
                        put_deltas_to_server(gradients, master_url)
                    except Exception:
                        print("Timeout error from partition %s" % partition_id)
            elif mini_batch_size >= 1:
                for r in range(0, len(features), mini_batch_size):
                    gradients = []
                    weights = get_server_weights(master_url)
                    tensorflow_set_weights(weights, vs=trainable_variables)
                    feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels, mini_batch_size, idx=r)
                    for x in range(len(grads)):
                        gradients.append(grads[x][0].eval(feed_dict=feed_dict))
                    try:
                        put_deltas_to_server(gradients, master_url)
                    except Exception:
                        print("Timeout error from partition %s" % partition_id)
            else:
                gradients = []
                feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels, mini_batch_size)
                for x in range(len(grads)):
                    gradients.append(grads[x][0].eval(feed_dict=feed_dict))
                try:
                    put_deltas_to_server(gradients, master_url)
                except Exception:
                    print("Timeout error from partition %s" % partition_id)

            if verbose or loss_callback:
                feed_dict = handle_feed_dict(features, tfInput, tfLabel, labels, -1)
                loss = sess.run(loss_variable, feed_dict=feed_dict)
                if verbose:
                    print("Partition Id: %s, Iteration: %i, Loss: %f" % (partition_id, i, loss))
                if loss_callback:
                    loss_callback(loss, i, partition_id)
Example #16
def load_tf_graph_def(graph_file_name: str = "",
                      is_binary: bool = True,
                      checkpoint: str = "",
                      model_dir: str = "",
                      saved_model_tags: list = None,
                      meta_graph_file: str = "",
                      user_output_node_names_list: list = []):
    # As a provisional solution, use native TF methods to load a model protobuf.
    graph_def = tf.GraphDef()
    if isinstance(graph_file_name, str) and (re.match(r'.*\.(ckpt|meta)$',
                                                      graph_file_name)):
        print(
            '[ WARNING ] The value for the --input_model command line parameter ends with ".ckpt" or ".meta" '
            'extension.\n'
            'It means that the model is not frozen.\n'
            'To load non frozen model to Model Optimizer run:'
            '\n\n1. For "*.ckpt" file:'
            '\n- if inference graph is in binary format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pb" --input_checkpoint "path/to/*.ckpt"'
            '\n- if inference graph is in text format'
            '\npython3 mo_tf.py --input_model "path/to/inference_graph.pbtxt" --input_model_is_text '
            '--input_checkpoint "path/to/*.ckpt"'
            '\n\n2. For "*.meta" file:'
            '\npython3 mo_tf.py --input_meta_graph "path/to/*.meta"')
    variables_values = {}
    try:
        if graph_file_name and not meta_graph_file and not checkpoint:
            # frozen graph
            return read_file_to_graph_def(graph_def, graph_file_name,
                                          is_binary), variables_values
        if graph_file_name and not meta_graph_file and checkpoint:
            # inference graph and checkpoint
            graph_def = read_file_to_graph_def(graph_def, graph_file_name,
                                               is_binary)
            outputs = get_output_node_names_list(graph_def,
                                                 user_output_node_names_list)
            if os.path.isfile(checkpoint):
                graph_def = freeze_checkpoint(graph_def=graph_def,
                                              checkpoint=checkpoint,
                                              output_node_names=outputs)
            elif os.path.isdir(checkpoint):
                graph_def, variables_values = freeze_checkpoints(
                    graph_def=graph_def,
                    checkpoint_dir=checkpoint,
                    output_node_names=outputs)
            # we are sure that checkpoint is existing file or directory due to cli_parser configuration
            return graph_def, variables_values
        if not graph_file_name and meta_graph_file:
            meta_graph_file = deducing_metagraph_path(meta_graph_file)
            input_meta_graph_def = read_file_to_graph_def(
                tf.MetaGraphDef(), meta_graph_file, is_binary)
            # pylint: disable=no-member
            with tf.Session() as sess:
                restorer = tf.train.import_meta_graph(input_meta_graph_def)
                restorer.restore(sess, re.sub(r'\.meta$', '', meta_graph_file))
                outputs = get_output_node_names_list(
                    input_meta_graph_def.graph_def,
                    user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, input_meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
        if model_dir:
            # saved model directory
            tags = saved_model_tags if saved_model_tags is not None else [
                tf.saved_model.tag_constants.SERVING
            ]
            with tf.Session() as sess:
                meta_graph_def = tf.saved_model.loader.load(
                    sess, tags, model_dir)
                outputs = get_output_node_names_list(
                    meta_graph_def.graph_def, user_output_node_names_list)
                graph_def = tf.graph_util.convert_variables_to_constants(
                    sess, meta_graph_def.graph_def, outputs)
                return graph_def, variables_values
    except Exception as e:
        raise FrameworkError('Cannot load input model: {}', e) from e
    raise Error("Unknown configuration of input model parameters")
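
A hedged example of driving the loader in its metagraph mode (the paths and output node name are placeholders):

graph_def, variable_values = load_tf_graph_def(
    meta_graph_file='/path/to/model.ckpt.meta',
    user_output_node_names_list=['output'])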
Example #17
    def _ProcessEvent(self, event):
        """Called whenever an event is loaded."""
        if self._first_event_timestamp is None:
            self._first_event_timestamp = event.wall_time

        if event.HasField('file_version'):
            new_file_version = _ParseFileVersion(event.file_version)
            if self.file_version and self.file_version != new_file_version:
                ## This should not happen.
                tf.logging.warn(
                    ('Found new file_version for event.proto. This will '
                     'affect purging logic for TensorFlow restarts. '
                     'Old: {0} New: {1}').format(self.file_version,
                                                 new_file_version))
            self.file_version = new_file_version

        self._MaybePurgeOrphanedData(event)

        ## Process the event.
        # GraphDef and MetaGraphDef are handled in a special way:
        # If no graph_def Event is available, but a meta_graph_def is, and it
        # contains a graph_def, then use the meta_graph_def.graph_def as our graph.
        # If a graph_def Event is available, always prefer it to the graph_def
        # inside the meta_graph_def.
        if event.HasField('graph_def'):
            # print('graph_def')
            if self._graph is not None:
                tf.logging.warn(
                    ('Found more than one graph event per run, or there was '
                     'a metagraph containing a graph_def, as well as one or '
                     'more graph events.  Overwriting the graph with the '
                     'newest event.'))
            self._graph = event.graph_def
            self._graph_from_metagraph = False
        elif event.HasField('meta_graph_def'):
            # print('meta_graph_def')
            if self._meta_graph is not None:
                tf.logging.warn(
                    ('Found more than one metagraph event per run. '
                     'Overwriting the metagraph with the newest event.'))
            self._meta_graph = event.meta_graph_def
            if self._graph is None or self._graph_from_metagraph:
                # We may have a graph_def in the metagraph.  If so, and no
                # graph_def is directly available, use this one instead.
                meta_graph = tf.MetaGraphDef()
                meta_graph.ParseFromString(self._meta_graph)
                if meta_graph.graph_def:
                    if self._graph is not None:
                        tf.logging.warn((
                            'Found multiple metagraphs containing graph_defs, '
                            'but did not find any graph events.  Overwriting the '
                            'graph with the newest metagraph version.'))
                    self._graph_from_metagraph = True
                    self._graph = meta_graph.graph_def.SerializeToString()
        elif event.HasField('tagged_run_metadata'):
            # print('tagged_run_metadata')
            tag = event.tagged_run_metadata.tag
            if tag in self._tagged_metadata:
                tf.logging.warn(
                    'Found more than one "run metadata" event with tag ' +
                    tag + '. Overwriting it with the newest event.')
            self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
Example #18

import tensorflow as tf
import sys

gf = tf.MetaGraphDef()
with open(sys.argv[1], "rb") as f:
    gf.ParseFromString(f.read())
print(gf)
Example #19
def convert_int8(input_model_dir, output_model_dir, batch_size, precision_mode,
                 calib_image_dir, input_tensor, output_tensor, epochs):

    # (TODO) Need to check if we need Tesla T4 when conversion.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Get path to calibration data.
    calibration_files = get_calibration_files(calib_image_dir, 'validation*')

    # Create dataset and apply preprocess
    # (TODO) Get num cpus to set appropriate number to num_parallel_calls
    dataset = tf.data.TFRecordDataset(calibration_files)
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            map_func=preprocess,
            batch_size=batch_size,
            num_parallel_calls=multiprocessing.cpu_count()))
    """
  Step 1: Creating the calibration graph.
  """

    # Create TF-TRT INT8 calibration graph.
    trt_int8_calib_graph = trt.create_inference_graph(
        input_graph_def=None,
        outputs=[output_tensor],
        max_batch_size=batch_size,
        input_saved_model_dir=input_model_dir,
        precision_mode=precision_mode)

    # Calibrate graph.
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        tf.logging.info('preparing calibration data...')
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()

        tf.logging.info('Loading INT8 calibration graph...')
        output_node = tf.import_graph_def(trt_int8_calib_graph,
                                          return_elements=[output_tensor],
                                          name='')

        tf.logging.info('Calibrate model with calibration data...')
        for _ in range(epochs):
            sess.run(output_node,
                     feed_dict={input_tensor: sess.run(next_element)[0]})
    """
  Step 2: Converting the calibration graph to inference graph
  """
    tf.logging.info('Creating TF-TRT INT8 inference engine...')
    trt_int8_calibrated_graph = trt.calib_graph_to_infer_graph(
        trt_int8_calib_graph)

    # Copy MetaGraph from base model.
    with tf.Session(graph=tf.Graph(), config=config) as sess:
        base_model = tf.saved_model.loader.load(
            sess, [tf.saved_model.tag_constants.SERVING], input_model_dir)

        metagraph = tf.MetaGraphDef()
        metagraph.graph_def.CopyFrom(trt_int8_calibrated_graph)
        for key in base_model.collection_def:
            if key not in [
                    'variables', 'local_variables', 'model_variables',
                    'trainable_variables', 'train_op', 'table_initializer'
            ]:
                metagraph.collection_def[key].CopyFrom(
                    base_model.collection_def[key])

        metagraph.meta_info_def.CopyFrom(base_model.meta_info_def)
        for key in base_model.signature_def:
            metagraph.signature_def[key].CopyFrom(
                base_model.signature_def[key])

    saved_model_builder = (
        tf.saved_model.builder.SavedModelBuilder(output_model_dir))

    # Write SavedModel with INT8 precision.
    with tf.Graph().as_default():
        tf.graph_util.import_graph_def(trt_int8_calibrated_graph,
                                       return_elements=[output_tensor],
                                       name='')
        with tf.Session(config=config) as sess:
            saved_model_builder.add_meta_graph_and_variables(
                sess, ('serve', ), signature_def_map=metagraph.signature_def)

    # Ignore other meta graphs from the input SavedModel.
    saved_model_builder.save()
Example #20

    def export_graph(self, saved_model_dir):
        # Collapse supported subgraphs
        light_graph = self.get_collapsed_light_graph(self._light_graph)

        # Check meta graph info
        meta_graph_info = light_graph.meta_graph_info()
        if (tf_saved_model_base_importer.ImportTFSavedModelBase.
                PARTIAL_META_GRAPH_DEF
                not in meta_graph_info.original_graph_info):
            raise ValueError(
                "Could not find partial meta graph def in meta graph info: {}".
                format(meta_graph_info))

        # Convert to TF Graph Def
        meta_graph = tf.MetaGraphDef()
        meta_graph.ParseFromString(meta_graph_info.original_graph_info[
            tf_saved_model_base_importer.ImportTFSavedModelBase.
            PARTIAL_META_GRAPH_DEF].v)
        graph_def = meta_graph.graph_def
        variable_values = {}
        for lnf in light_graph.nodes():
            node_def = graph_def.node.add()
            if lnf.HasField(lgf_pb2.LNF.subgraph.DESCRIPTOR.name):
                self._subgraph_node_to_node_def(node_def, lnf)
            else:
                self._original_node_def(node_def, lnf, variable_values)

        # Add placeholders for nodes that don't exist in the light_graph
        node_names = set(n.name for n in graph_def.node)
        for e in light_graph.input_edges():
            if not light_graph.has_node(e.name):
                graph_def.node.add().CopyFrom(
                    self._edge_to_node_def_placeholder(e))
                node_names.add(e.name)

        # Load the graph def into a session (load custom ops first)
        tf_ops.load_ops()
        with tf.Session(graph=tf.Graph()) as sess:
            if meta_graph.saver_def.save_tensor_name != "":
                # TF only allows importing a meta graph def when the saver def
                # is fully initialized. This should only be the case for graphs
                # with TF variables.
                assert len(variable_values) > 0
                tf.train.import_meta_graph(meta_graph)
            else:
                assert len(variable_values) == 0
                tf.import_graph_def(meta_graph.graph_def, name="")

            # Load variables
            for v in (tf.global_variables() + tf.local_variables() +
                      tf.trainable_variables() + tf.model_variables()):
                node_name, _, _ = tf_saved_model_base_importer.ImportTFSavedModelBase.\
                    get_node_name_and_output_index(v.name)
                if node_name not in variable_values:
                    raise ValueError(
                        "Could not find value for variable {}".format(
                            node_name))
                v.load(variable_values[node_name])

            # Get input and output tensors from the graph
            input_tensors = self._get_tf_tensors_from_graph(
                sess.graph, light_graph.input_edges())
            output_tensors = self._get_tf_tensors_from_graph(
                sess.graph, light_graph.output_edges())

            # Create the saved model
            self.save_model(saved_model_dir,
                            sess,
                            input_tensors=input_tensors,
                            output_tensors=output_tensors,
                            output_node_names=light_graph.output_node_names())