Exemplo n.º 1
0
    def __init__(self, base_model_path, scaffold_path, **kwargs):
        """Set up a trainable network on top of a base model.

        Args:
            base_model_path: Directory of the base model. Its last path
                component (e.g. ``inception_resnet_v2``) is used as the
                network name for the preprocessing and nets factories.
            scaffold_path: Path handed to ``load_labels``; the number of
                labels returned determines the number of output classes.
            **kwargs: Setting overrides applied on top of
                ``DEFAULT_SETTINGS`` for this instance only.
        """
        self.base_model_path = base_model_path
        # Last path component, e.g. 'inception_resnet_v2'. Trailing slashes
        # are stripped so 'models/foo/' yields 'foo' instead of ''.
        self.model_name = base_model_path.rstrip('/').split('/')[-1]
        self.scaffold_path = scaffold_path

        # Work on a copy: mutating DEFAULT_SETTINGS itself would leak this
        # instance's overrides (and the derived base_checkpoint_path below)
        # into every instance created afterwards.
        self.settings = dict(DEFAULT_SETTINGS)
        self.settings.update(kwargs)

        if 'base_checkpoint_path' not in self.settings:
            checkpoint_dir = os.path.abspath(
                os.path.join(self.base_model_path, 'state'))
            # Prefer the checkpoint reported by latest_checkpoint for
            # state/; otherwise fall back to '<model_name>.ckpt' there.
            base_path = tf.train.latest_checkpoint(checkpoint_dir)
            if base_path is None:
                base_path = os.path.join(self.base_model_path, 'state',
                                         self.model_name + '.ckpt')
            self.settings['base_checkpoint_path'] = base_path

        self.labels = load_labels(scaffold_path)
        self.num_classes = len(self.labels)
        self.preprocess = preprocessing_factory.get_preprocessing(
            self.model_name, is_training=True)
        self.model_definition = nets_factory.get_network_fn(self.model_name,
                                                            self.num_classes,
                                                            is_training=True)
Exemplo n.º 2
0
    def __init__(self, model_dir, model_name):
        """Prepare an inference-mode network whose outputs map to labels.

        Resets the default TensorFlow graph, then indexes the labels loaded
        from ``model_dir`` by their Softmax ``node_id``.

        Args:
            model_dir: Directory handed to ``load_labels``.
            model_name: Network name for the preprocessing and nets
                factories, e.g. ``inception_resnet_v2``.

        Raises:
            ValueError: If a label has no ``node_id``, or if two labels
                share the same ``node_id``.
        """
        tf.reset_default_graph()
        self.model_dir = model_dir
        self.model_name = model_name  # ex inception_resnet_v2
        self.preprocess = preprocessing_factory.get_preprocessing(
            self.model_name, is_training=False)
        labels = load_labels(model_dir)
        self.labels_by_node_id = {}
        for label_id, label in labels.items():
            node_id = label.get('node_id', None)
            if node_id is None:
                raise ValueError(
                    'No Softmax node_id is known for label {}, aborting'.
                    format(label_id))
            # A repeated node_id would silently drop a label and shrink
            # num_classes below, producing a network of the wrong width.
            if node_id in self.labels_by_node_id:
                raise ValueError(
                    'Duplicate Softmax node_id {} for label {}, aborting'.
                    format(node_id, label_id))
            self.labels_by_node_id[node_id] = label

        # One output class per distinct Softmax node.
        self.num_classes = len(self.labels_by_node_id)
        self.model_definition = nets_factory.get_network_fn(self.model_name,
                                                            self.num_classes,
                                                            is_training=False)
Exemplo n.º 3
0
# Build an evaluation graph for the CIFAR-10 'test' split on EVAL_DEVICE:
# read examples in order, preprocess them, batch them, run vgg_cifar10 and
# track streaming accuracy of the argmax predictions.
with tf.Graph().as_default() as g:
    with g.device(EVAL_DEVICE):
        dataset = dataset_factory.get_dataset('cifar10', 'test', DATA_DIR)

        tf_global_step = slim.get_or_create_global_step()

        # shuffle=False keeps the test examples in their stored order.
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            shuffle=False,
            common_queue_capacity=2 * BATCH_SIZE,
            common_queue_min=BATCH_SIZE)

        [image, label] = provider.get(['image', 'label'])

        # Evaluation-mode ('cifarnet', is_training=False) preprocessing.
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            'cifarnet', is_training=False)

        # presumably (image, height, width) -> 32x32 output; confirm against
        # the preprocessing_factory signature.
        image = image_preprocessing_fn(image, 32, 32)

        # Group single examples into batches of BATCH_SIZE using 2 threads.
        images, labels = tf.train.batch([image, label],
                                        batch_size=BATCH_SIZE,
                                        num_threads=2,
                                        capacity=5 * BATCH_SIZE)

        logits, end_points = vgg_cifar10.inference(images)

        # Predicted class = argmax over axis 1 (assumes logits are
        # (batch, num_classes) — TODO confirm).
        predictions = tf.argmax(logits, 1)

        # `accuracy` is the running value; `update_op` has to be run for each
        # batch to accumulate it.
        accuracy, update_op = slim.metrics.streaming_accuracy(
            predictions, labels)
        # NOTE(review): tf.scalar_summary is the pre-1.0 TensorFlow API
        # (replaced by tf.summary.scalar) — confirm the TF version in use.
        tf.scalar_summary('eval/accuracy', accuracy)