예제 #1
0
    def evaluate(checkpoint_file):
        """Score one checkpoint with a k-nearest-neighbour label vote.

        The score is cached in ``<checkpoint path>.knn<top_k>.txt``; if that
        file already exists the checkpoint is skipped. Otherwise every worker
        iterates ``train_ds`` through ``update_buffer``; afterwards only the
        worker with ``hvd.rank() == 0`` scores the "val" split. Each image's
        predicted label is the most common label among its ``topk_indices``
        neighbours (labels looked up in ``all_train_files``). That worker logs
        the accuracy and writes the result file.

        Args:
            checkpoint_file: forwarded to ``get_checkpoint_path`` and
                ``SmartInit``.
        """
        checkpoint_path = get_checkpoint_path(checkpoint_file)
        result_file = checkpoint_path + f".knn{args.top_k}.txt"
        if os.path.isfile(result_file):
            logger.info(f"Skipping evaluation of {result_file}.")
            return

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            SmartInit(checkpoint_file).init(sess)

            # Push every training batch through the buffer-update op.
            n_train_batches = len(train_ds)
            for images, sample_ids in tqdm.tqdm(train_ds,
                                                total=n_train_batches):
                feed = {image_input: images, idx_input: sample_ids}
                sess.run(update_buffer, feed_dict=feed)

            # Only rank 0 scores and writes the result.
            if hvd.rank() != 0:
                return

            acc = Accuracy()
            val_df = get_imagenet_dataflow(args.data, "val", local_batch_size)
            val_df.reset_state()

            for images, labels in val_df:
                neighbour_ids = sess.run(topk_indices,
                                         feed_dict={image_input: images})
                for row, label in zip(neighbour_ids, labels):
                    votes = Counter([all_train_files[k][1] for k in row])
                    winner = votes.most_common(1)[0][0]
                    acc.feed(winner == label, total=1)

            logger.info(
                f"Accuracy of {checkpoint_file}: {acc.accuracy} out of {acc.total}"
            )
            with open(result_file, "w") as f:
                f.write(str(acc.accuracy))
 def __init__(self, model_path, prefix=None):
     """Remember which checkpoint to restore and an optional name prefix.

     Args:
         model_path (str): a model name (model-xxxx) or a ``checkpoint`` file;
             normalised once here through ``get_checkpoint_path``.
         prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
     """
     self.path = get_checkpoint_path(model_path)
     self.prefix = prefix
예제 #3
0
파일: misc.py 프로젝트: murph3d/DLD
def save_model(model_paths, model, target="", compact=False):
    """Save a model to given dir.

    For every entry of ``model_paths`` a ``tp.PredictConfig`` (input
    ``"input"``, output ``"emb"``) is built and passed to tensorpack's
    ``ModelExporter``.

    Args:
        model_paths (iterable of str): checkpoint paths to export. An entry
            whose real path is not an existing file is logged as an error and
            skipped.
        model: model description forwarded to ``tp.PredictConfig``
            (presumably a tensorpack ``ModelDesc`` -- confirm with callers).
        target (str): output directory. When empty (the default) each model
            is written into the directory containing its checkpoint file.
            Missing directories are created.
        compact (bool): if True, call ``export_compact`` to write
            ``<target>/<checkpoint name>.pb``; otherwise call
            ``export_serving(target)``.
    """
    from os import path
    from os import makedirs

    import tensorpack as tp

    from tensorpack.tfutils.varmanip import get_checkpoint_path
    from tensorpack.tfutils.export import ModelExporter

    import misc.logger as logger
    _L = logger.getLogger("Saver")

    # Value comparison: `is ""` depends on string interning and triggers a
    # SyntaxWarning on Python 3.8+.
    save_to_modeldir = target == ""

    for model_path in model_paths:
        # get model path
        real_path = get_checkpoint_path(model_path)
        abs_p = path.realpath(model_path)
        if not path.isfile(abs_p):
            _L.error("{} is not a model file".format(model_path))
            continue

        # Default: save to the same folder as the model. Use a local instead
        # of rebinding the `target` parameter.
        out_dir = path.dirname(abs_p) if save_to_modeldir else target

        # exist_ok avoids the exists()/makedirs() race.
        makedirs(out_dir, exist_ok=True)

        conf = tp.PredictConfig(session_init=tp.get_model_loader(model_path),
                                model=model,
                                input_names=["input"],
                                output_names=["emb"])

        exporter = ModelExporter(conf)
        name = path.basename(real_path)
        if compact:
            out = path.join(out_dir, "{}.pb".format(name))
            _L.info("compact saving {} to {}".format(name, out))
            exporter.export_compact(out)
        else:
            _L.info("saving {} to {}".format(name, out_dir))
            exporter.export_serving(out_dir)
예제 #4
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py

import tensorflow as tf
import numpy as np
import six
import sys
import pprint

from tensorpack.tfutils.varmanip import get_checkpoint_path

def get_variable_shapes(fpath):
    """Return a ``{variable name: shape}`` mapping for a parameter file.

    The format is chosen by file extension:

    * ``.npy``: a dict of arrays saved with ``np.save``.
    * ``.npz``: an archive saved with ``np.savez``; each member is a variable.
    * anything else: a TensorFlow checkpoint, resolved through
      ``get_checkpoint_path`` and read with ``tf.train.NewCheckpointReader``.

    Args:
        fpath (str): path to the parameter file or checkpoint.

    Returns:
        dict: variable name -> shape.
    """
    if fpath.endswith('.npy'):
        # A dict stored in .npy is a pickled object array, which numpy
        # >= 1.16.3 refuses to load unless allow_pickle=True. Unpickling can
        # execute arbitrary code, so only inspect files from a trusted source.
        params = np.load(fpath, encoding='latin1', allow_pickle=True).item()
        return {k: v.shape for k, v in params.items()}
    if fpath.endswith('.npz'):
        # Use NpzFile as a context manager so the archive is closed.
        with np.load(fpath) as params:
            return {k: params[k].shape for k in params.files}
    path = get_checkpoint_path(fpath)
    reader = tf.train.NewCheckpointReader(path)
    return reader.get_variable_to_shape_map()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit("Usage: {} <checkpoint | file.npy | file.npz>".format(sys.argv[0]))
    pprint.pprint(get_variable_shapes(sys.argv[1]))
예제 #5
0
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py

import numpy as np
import pprint
import sys
import six
import tensorflow as tf

from tensorpack.tfutils.varmanip import get_checkpoint_path

def get_variable_shapes(fpath):
    """Return a ``{variable name: shape}`` mapping for a parameter file.

    The format is chosen by file extension:

    * ``.npy``: a dict of arrays saved with ``np.save``.
    * ``.npz``: an archive saved with ``np.savez``; each member is a variable.
    * anything else: a TensorFlow checkpoint, resolved through
      ``get_checkpoint_path`` and read with ``tf.train.NewCheckpointReader``.

    Args:
        fpath (str): path to the parameter file or checkpoint.

    Returns:
        dict: variable name -> shape.
    """
    if fpath.endswith('.npy'):
        # A dict stored in .npy is a pickled object array, which numpy
        # >= 1.16.3 refuses to load unless allow_pickle=True. Unpickling can
        # execute arbitrary code, so only inspect files from a trusted source.
        params = np.load(fpath, encoding='latin1', allow_pickle=True).item()
        return {k: v.shape for k, v in params.items()}
    if fpath.endswith('.npz'):
        # Use NpzFile as a context manager so the archive is closed.
        with np.load(fpath) as params:
            return {k: params[k].shape for k in params.files}
    path = get_checkpoint_path(fpath)
    reader = tf.train.NewCheckpointReader(path)
    return reader.get_variable_to_shape_map()


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit("Usage: {} <checkpoint | file.npy | file.npz>".format(sys.argv[0]))
    pprint.pprint(get_variable_shapes(sys.argv[1]))
예제 #6
0
파일: misc.py 프로젝트: murph3d/DLD
def test(dataset, model_paths, model=None):
    """Evaluate each checkpoint and write the results as TensorBoard summaries.

    For every entry of ``model_paths``:

    * read ``global_step`` from the checkpoint;
    * build a predictor with ``get_predictor``;
    * run ``config.batch_num`` batches through ``test_batch``.

    Optionally, depending on ``config`` flags:

    * hand CNN results to ``dataset.ds.ds.save_results``;
    * stream distance matrices over ``dataset.client``;
    * collect image and embedding-projector outputs.

    Finally, ROC plots of CNN vs. LBD distances are added as image summaries
    under ``<logger dir>/cnn`` and ``<logger dir>/lbd``.

    Args:
        dataset: dataflow whose ``reset_state`` is called; it may expose
            ``ds.ds.save_results`` and a ``client`` attribute.
        model_paths (list of str): checkpoint paths; their count is passed to
            ``fake_data``.
        model: forwarded to ``get_predictor``.

    The TensorFlow session and both summary writers are closed even if an
    evaluation step raises.
    """
    import tensorpack as tp
    import tensorflow as tf

    from tensorpack.tfutils.varmanip import get_checkpoint_path
    # NOTE(review): tensorflow.contrib does not exist in TF 2.x; this import
    # ties the function to TF 1.x.
    from tensorflow.contrib.tensorboard.plugins import projector

    import numpy as np

    import misc.logger as logger

    import cnn.plotter as plotter

    import struct
    from contextlib import ExitStack

    def pairwise(sess):
        from cnn.main import pairwise_distance
        descs = tf.compat.v1.placeholder(tf.float32, (
            None,
            8,
        ), "descs")
        cnn_dists = pairwise_distance(descs, True)
        pred = tp.OnlinePredictor([descs], [cnn_dists], sess=sess)

        def calculate(*args, **kwargs):
            return pred(*args, **kwargs)[0]

        return calculate

    _L = logger.getLogger("test")  # noqa

    with ExitStack() as cleanup:
        # Callbacks run in reverse registration order on any exit:
        # lbd writer, cnn writer, then the session.
        sess = tf.compat.v1.Session()
        cleanup.callback(sess.close)
        cnn_summ = tf.compat.v1.summary.FileWriter(tp.logger.get_logger_dir() +
                                                   "/cnn")
        cleanup.callback(cnn_summ.close)
        lbd_summ = tf.compat.v1.summary.FileWriter(tp.logger.get_logger_dir() +
                                                   "/lbd")
        cleanup.callback(lbd_summ.close)

        cnn_logdir = cnn_summ.get_logdir()

        # [[Match...], [Label...], [Descriptor...]]
        dataset.reset_state()
        data = fake_data(dataset, len(model_paths), _L)

        pairwise_distance = pairwise(sess)

        # Warn only once if results cannot be saved.
        warned_no_save = False

        for model_path in model_paths:
            # get global step
            real_path = get_checkpoint_path(model_path)
            reader = tf.compat.v1.train.NewCheckpointReader(real_path)
            global_step = reader.get_tensor("global_step")

            # predictor
            pred = get_predictor(real_path, model)

            img_info = [[], []]
            imgs = []

            # summaries
            cnn_summaries = tf.compat.v1.Summary()
            lbd_summaries = tf.compat.v1.Summary()

            # collected data for ROC curves
            roc_data = []

            for batch_n in range(config.batch_num):
                # test the batch
                (emb, heights, labels, left, timgs, tinfo,
                 cnn_dists, lbd_dists) = test_batch(pairwise_distance, pred,
                                                    next(data), batch_n)
                roc_data.append((labels, left, cnn_dists, lbd_dists))

                # Let the underlying dataset store the CNN results if it can.
                if config.save_results:
                    nested = dataset.ds.ds
                    save_op = getattr(nested, 'save_results', None)
                    if callable(save_op):
                        save_op(cnn_desc_list=list(emb),
                                label_list=list(labels),
                                left_list=list(left))
                    elif not warned_no_save:
                        _L.warning(
                            "config.save_results is set, but the dataset has "
                            "no callable 'save_results'; CNN results are not "
                            "saved")
                        warned_no_save = True

                if config.return_results and hasattr(
                        dataset, 'client') and dataset.client:
                    # Header: first two dimensions of the distance matrix.
                    message = struct.pack('II', *cnn_dists.shape[:2])
                    message += cnn_dists.tobytes()
                    dataset.client.send('c', message, wait=False)

                # generate image output
                if config.tp_imgs:
                    cimgs = [
                        np.resize(el, (heights[i], *el.shape[1:]))
                        for i, el in enumerate(timgs)
                    ]
                    imgs.append(cimgs)
                    for i in [0, 1]:
                        img_info[i].append(tinfo[i])

                # generate projection output
                if config.tp_proj:
                    mdata_name = "metadata_{}.tsv".format(batch_n)
                    mdata_path = "{}/{}".format(cnn_logdir, mdata_name)
                    with open(mdata_path, "w") as mfile:
                        for label in labels:
                            mfile.write("{}\n".format(label))
                    sprite = make_sprite(timgs, timgs[0].shape, cnn_logdir,
                                         batch_n)
                    sprite_size = max(timgs[0].shape)

                    embv = tf.Variable(emb, name="embeddings")
                    # compat.v1 forms, consistent with the rest of this
                    # function. The bare tf.* aliases are gone in TF 2.x.
                    initop = tf.compat.v1.variables_initializer([embv])
                    pconf = projector.ProjectorConfig()
                    embconf = pconf.embeddings.add()
                    embconf.tensor_name = embv.name
                    embconf.metadata_path = mdata_name
                    embconf.sprite.image_path = sprite
                    embconf.sprite.single_image_dim.extend(
                        [sprite_size, sprite_size])
                    projector.visualize_embeddings(cnn_summ, pconf)
                    sess.run(initop)
                    saver = tf.compat.v1.train.Saver()
                    saver.save(sess, "{}/embeddings.ckpt".format(cnn_logdir),
                               batch_n)

            # generate ROC
            tap_dists, tan_dists = split_dists(roc_data)

            cnn_plots = plotter.plot_roc(tap_dists[0],
                                         tan_dists[0],
                                         "cnn: {}".format(config.depth),
                                         color="g")
            plotter.plot_roc(tap_dists[1],
                             tan_dists[1],
                             "lbd",
                             color="b",
                             figs=cnn_plots)

            # The first figure is the full ROC, the second a zoomed view.
            suffix = ""
            for i in range(2):
                img = plotter.plot_to_np(cnn_plots[i])
                img = np.expand_dims(img, axis=0)

                s = tp.summary.create_image_summary(
                    "ROC{}/{}".format(suffix, global_step), img)
                cnn_summaries.value.extend(s.value)
                suffix = "_zoomed"

            # add image summary
            if config.tp_imgs:
                for i, cnn_batch_info in enumerate(img_info[0]):
                    for j, cnn_info in enumerate(cnn_batch_info):
                        lbd_info = img_info[1][i][j]

                        generate_tensorboard_img_summary(cnn_summaries,
                                                         cnn_info, imgs[i],
                                                         "imgs/cnn",
                                                         global_step, i)
                        generate_tensorboard_img_summary(lbd_summaries,
                                                         lbd_info, imgs[i],
                                                         "imgs/lbd",
                                                         global_step, i)

            cnn_summ.add_summary(cnn_summaries, global_step)
            lbd_summ.add_summary(lbd_summaries, global_step)

        lbd_summ.flush()
        cnn_summ.flush()