def get_dataset(self, split_name, dataset_dir):
    """Assemble the slim ``Dataset`` describing one split of the TFRecords.

    Args:
        split_name: Name of the split to load; only 'train' and 'val' are
            accepted.
        dataset_dir: Directory that is searched for ``<split>_*.tfrecord``
            files and for an optional ``label_map.txt``.

    Returns:
        A ``slim.dataset.Dataset`` whose data source is the shard glob for
        the split, sized from ``self.splits_to_sizes[split_name]``.
        ``labels_to_names`` is the parsed label map when the label file is
        present and ``None`` otherwise.

    Raises:
        AssertionError: If ``split_name`` is neither 'train' nor 'val'.
    """
    assert split_name in ['train', 'val']

    shard_glob = os.path.join(dataset_dir, '%s_*.tfrecord' % split_name)

    # How each serialized tf.Example field is parsed, with the default that
    # is substituted when the field is missing from a record.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature(
            (), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature(
            (), tf.string, default_value='png'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # Maps the item names exposed by the dataset onto the parsed features.
    item_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    example_decoder = slim.tfexample_decoder.TFExampleDecoder(
        feature_spec, item_handlers)

    # The label map is optional; without it the dataset carries no names.
    label_file = 'label_map.txt'
    if dataset_utils.has_labels(dataset_dir, label_file):
        class_names = dataset_utils.read_label_file(dataset_dir, label_file)
    else:
        class_names = None

    return slim.dataset.Dataset(
        data_sources=shard_glob,
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=self.splits_to_sizes[split_name],
        items_to_descriptions=self.items_to_descriptions,
        num_classes=self.num_classes,
        labels_to_names=class_names)
def infer(self, image_path):
    """Classify one image with the model named in ``self.config``.

    Reads ``model_name`` and ``checkpoint_path`` from ``self.config``, loads
    ``label_map.txt`` from the checkpoint directory, restores the latest
    checkpoint found there and runs a single forward pass over the image.

    Args:
        image_path: Path of the image file. It is decoded with
            ``tf.image.decode_jpeg`` as a 3-channel image.

    Returns:
        A list with one ``{'type': <label name>, 'prob': <probability>}``
        dict per class, sorted by descending probability. The probability
        is stored as a string. ``None`` is returned (after logging an
        error) when the configuration, the label map or the checkpoint is
        missing.
    """
    if self.config is None:
        tf.logging.error('Configuration is None')
        return None

    model_name = self.config['model_name']
    checkpoint_path = self.config['checkpoint_path']

    if not dataset_utils.has_labels(checkpoint_path, 'label_map.txt'):
        tf.logging.error('No label map')
        return None
    labels_to_names = dataset_utils.read_label_file(
        checkpoint_path, 'label_map.txt')
    # NOTE(review): position i of the network output is paired with the
    # i-th key of the label map, in the map's own iteration order; confirm
    # that read_label_file yields class ids in output order.
    keys = list(labels_to_names.keys())

    # latest_checkpoint returns None when the directory has no checkpoint.
    # Bail out here instead of failing inside the restore function.
    model_path = tf.train.latest_checkpoint(checkpoint_path)
    if model_path is None:
        tf.logging.error('No checkpoint found in %s', checkpoint_path)
        return None

    with tf.Graph().as_default():
        image_string = tf.read_file(image_path)
        image = tf.image.decode_jpeg(image_string, channels=3)

        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            model_name, is_training=False)
        network_fn = nets_factory.get_network_fn(
            model_name, num_classes=len(keys), is_training=False)

        image_size = network_fn.default_image_size
        processed_image = image_preprocessing_fn(image, image_size, image_size)
        # The network is fed a batch, so add a leading axis of size one.
        image_expanded = tf.expand_dims(processed_image, axis=0)

        logits, _ = network_fn(image_expanded)
        probabilities = tf.nn.softmax(logits)

        init_fn = slim.assign_from_checkpoint_fn(
            model_path, slim.get_model_variables(scope_map[model_name]))

        with tf.Session() as sess:
            init_fn(sess)
            probs = sess.run(probabilities)

    # probs[0] holds the per-class probabilities of the single input image.
    result = [
        {'type': labels_to_names[key], 'prob': str(prob)}
        for key, prob in zip(keys, probs[0])
    ]
    return sorted(result, key=lambda k: float(k['prob']), reverse=True)