def read_devkit(f):
    """Load class metadata and raw validation labels from the devkit.

    Parameters
    ----------
    f : str or file-like object
        Filename or open handle of the gzipped TAR archive holding the
        ILSVRC2010 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        As produced by :func:`read_metadata_mat_file`.
    cost_matrix : ndarray, 2-dimensional, uint8
        As produced by :func:`read_metadata_mat_file`.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        Validation-set labels (ILSVRC2010 IDs) that ship inside the
        development kit.

    """
    with tar_open(f) as devkit:
        # The .mat member describes the classes (hierarchy, descriptions).
        synsets, cost_matrix = read_metadata_mat_file(
            devkit.extractfile(DEVKIT_META_PATH))

        # The validation groundtruth is shipped in the devkit rather than
        # alongside the validation images.
        groundtruth_file = devkit.extractfile(DEVKIT_VALID_GROUNDTRUTH_PATH)
        raw_valid_groundtruth = numpy.loadtxt(groundtruth_file,
                                              dtype=numpy.int16)
    return synsets, cost_matrix, raw_valid_groundtruth
def read_devkit(f):
    """Return the synsets, cost matrix and raw validation labels.

    Parameters
    ----------
    f : str or file-like object
        The gzipped TAR archive (path or handle) containing the
        ILSVRC2010 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        See :func:`read_metadata_mat_file`.
    cost_matrix : ndarray, 2-dimensional, uint8
        See :func:`read_metadata_mat_file`.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        ILSVRC2010 validation labels, read from a text file inside the
        development kit archive.

    """
    with tar_open(f) as archive:
        metadata_handle = archive.extractfile(DEVKIT_META_PATH)
        synsets, cost_matrix = read_metadata_mat_file(metadata_handle)

        labels_handle = archive.extractfile(DEVKIT_VALID_GROUNDTRUTH_PATH)
        raw_valid_groundtruth = numpy.loadtxt(labels_handle,
                                              dtype=numpy.int16)
    return synsets, cost_matrix, raw_valid_groundtruth
def prepare_metadata(devkit_archive, test_archive=None):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.
    test_archive : str or file-like object, optional
        The filename or file-handle for the TAR archive of test images.
        Defaults to ``TEST_IMAGES_TAR``.

    Returns
    -------
    n_train : int
        The number of examples in the training set.
    valid_groundtruth : list of int
        The validation set groundtruth in terms of 0-based class
        indices.
    n_test : int
        The number of examples (regular-file members) in the test set
        archive.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    """
    # Resolve the default lazily so the module-level constant is looked
    # up at call time, exactly as before.
    if test_archive is None:
        test_archive = TEST_IMAGES_TAR

    # Read what's necessary from the development kit. Depending on the
    # devkit reader in scope, read_devkit returns either
    # (synsets, raw_valid_groundtruth) or
    # (synsets, cost_matrix, raw_valid_groundtruth); only the first and
    # last items are needed here, so accept both forms.
    devkit = read_devkit(devkit_archive)
    synsets, raw_valid_groundtruth = devkit[0], devkit[-1]

    # Mapping to take WordNet IDs to our internal 0-999 encoding. zip()
    # stops after 1000 items, so only the first 1000 synsets get an index.
    wnid_map = dict(zip((s.decode('utf8') for s in synsets['WNID']),
                        xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(zip(synsets['ILSVRC2012_ID'],
                                       xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [ilsvrc_id_to_zero_based[id_]
                         for id_ in raw_valid_groundtruth]

    # Count only regular files so directory entries in the test archive
    # are not counted as images.
    with tar_open(test_archive) as tar:
        n_test = sum(1 for info in tar if info.isfile())

    # Ascertain the number of filenames to prepare appropriate sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(len(valid_groundtruth)))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + len(valid_groundtruth) + n_test
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
def train_set_producer(socket, train_archive, patch_archive, wnid_map):
    """Stream training images (or their patched replacements) to a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket that receives the loaded images.
    train_archive : str or file-like object
        Filename or file handle of the outer training-set TAR archive,
        whose members are themselves opened as TAR archives.
    patch_archive : str or file-like object
        Filename or file handle of the TAR archive of patch images.
    wnid_map : dict
        Maps WordNet IDs to 0-based class indices. The WordNet ID is the
        part of each inner TAR's name before the first '.'.

    Raises
    ------
    ValueError
        If the number of images taken from the patch archive differs
        from the number of training patch images it contains.

    """
    replacements = extract_patch_images(patch_archive, 'train')
    patched_count = 0
    with tar_open(train_archive) as outer:
        for member in outer:
            with tar_open(outer.extractfile(member.name)) as inner:
                label = wnid_map[member.name.split('.')[0]]
                names = sorted(info.name for info in inner if info.isfile())
                loaded = (load_from_tar_or_patch(inner, name, replacements)
                          for name in names)
                basenames = (os.path.split(name)[-1] for name in names)
                for basename, (payload, was_patched) in equizip(basenames,
                                                                loaded):
                    if was_patched:
                        patched_count += 1
                    # Multipart message: (filename, label) header, then bytes.
                    socket.send_pyobj((basename, label), zmq.SNDMORE)
                    socket.send(payload)
    if patched_count != len(replacements):
        raise ValueError('not all patch images were used')
def prepare_metadata(devkit_archive, test_archive=None):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.
    test_archive : str or file-like object, optional
        The filename or file-handle for the TAR archive of test images.
        Defaults to ``TEST_IMAGES_TAR``.

    Returns
    -------
    n_train : int
        The number of examples in the training set.
    valid_groundtruth : list of int
        The validation set groundtruth in terms of 0-based class
        indices.
    n_test : int
        The number of examples (regular-file members) in the test set
        archive.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    """
    # Resolve the default lazily so the module-level constant is looked
    # up at call time, exactly as before.
    if test_archive is None:
        test_archive = TEST_IMAGES_TAR

    # Read what's necessary from the development kit. Depending on the
    # devkit reader in scope, read_devkit returns either
    # (synsets, raw_valid_groundtruth) or
    # (synsets, cost_matrix, raw_valid_groundtruth); only the first and
    # last items are needed here, so accept both forms.
    devkit = read_devkit(devkit_archive)
    synsets, raw_valid_groundtruth = devkit[0], devkit[-1]

    # Mapping to take WordNet IDs to our internal 0-999 encoding. zip()
    # stops after 1000 items, so only the first 1000 synsets get an index.
    wnid_map = dict(zip((s.decode('utf8') for s in synsets['WNID']),
                        xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(zip(synsets['ILSVRC2012_ID'],
                                       xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [ilsvrc_id_to_zero_based[id_]
                         for id_ in raw_valid_groundtruth]

    # Count only regular files so directory entries in the test archive
    # are not counted as images.
    with tar_open(test_archive) as tar:
        n_test = sum(1 for info in tar if info.isfile())

    # Ascertain the number of filenames to prepare appropriate sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(len(valid_groundtruth)))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + len(valid_groundtruth) + n_test
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
def extract_patch_images(f, which_set):
    """Extracts a dict of the "patch images" for ILSVRC2010.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle to the patch images TAR file.
    which_set : str
        Which set of images to extract. One of 'train', 'valid', 'test'.

    Returns
    -------
    dict
        A dictionary that maps filenames (without path) to a bytes
        object containing the replacement image.

    Raises
    ------
    ValueError
        If `which_set` is not one of 'train', 'valid', 'test'.

    Notes
    -----
    Certain images in the distributed archives are blank, or display
    an "image not available" banner. A separate TAR file of
    "patch images" is distributed with the corrected versions of
    these. It is this archive that this function is intended to read.

    Only ``.JPEG`` members whose parent directory name equals the
    requested set ('val' for 'valid') are returned; members with no
    parent directory are skipped.

    """
    if which_set not in ('train', 'valid', 'test'):
        raise ValueError('which_set must be one of train, valid, or test')
    # Inside the patch archive the validation directory is named 'val'.
    which_set = 'val' if which_set == 'valid' else which_set
    patch_images = {}
    with tar_open(f) as tar:
        for info_obj in tar:
            if not info_obj.name.endswith('.JPEG'):
                continue
            # TAR member names use '/' as the separator regardless of
            # os.path.sep.
            tokens = info_obj.name.split('/')
            # A JPEG at the archive root has no set directory; skip it
            # instead of failing with IndexError on tokens[-2].
            if len(tokens) < 2:
                continue
            if tokens[-2] != which_set:
                continue
            # Extract via the member object itself. With the stdlib
            # tarfile, extracting by name loads the full member index and
            # scans it for every image (quadratic overall), and resolves
            # duplicate names to the last occurrence rather than this one.
            patch_images[tokens[-1]] = tar.extractfile(info_obj).read()
    return patch_images
def other_set_producer(socket, which_set, image_archive, patch_archive,
                       groundtruth):
    """Push image files read from the valid/test set TAR to a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket on which to send images.
    which_set : str
        Which set of images is being processed. One of 'train', 'valid',
        'test'. Used for extracting the appropriate images from the patch
        archive.
    image_archive : str or file-like object
        The filename or file-handle for the TAR archive containing images.
    patch_archive : str or file-like object
        Filename or file handle for the TAR archive of patch images.
    groundtruth : iterable
        Iterable container containing scalar 0-based class index for each
        image, sorted by filename.

    Raises
    ------
    ValueError
        If the number of images replaced from the patch archive differs
        from the number of patch images found for `which_set`.

    """
    patch_images = extract_patch_images(patch_archive, which_set)
    num_patched = 0
    with tar_open(image_archive) as tar:
        # Members are sorted by full name; `groundtruth` must be supplied
        # in the same order for labels to line up with images.
        filenames = sorted(info.name for info in tar if info.isfile())
        images = (load_from_tar_or_patch(tar, filename, patch_images)
                  for filename in filenames)
        pathless_filenames = (os.path.split(fn)[-1] for fn in filenames)
        image_iterator = equizip(images, pathless_filenames, groundtruth)
        for (image_data, patched), filename, class_index in image_iterator:
            if patched:
                num_patched += 1
            # Multipart message: (filename, label) header, then raw bytes
            # sent without an extra copy.
            socket.send_pyobj((filename, class_index), zmq.SNDMORE)
            socket.send(image_data, copy=False)
    if num_patched != len(patch_images):
        # Previously a bare, message-less ``Exception``. ValueError
        # subclasses Exception, so existing handlers still catch it.
        raise ValueError('not all patch images were used')