예제 #1
0
def read_devkit(f):
    """Pull the needed tables out of the ILSVRC2010 development kit.

    Parameters
    ----------
    f : str or file-like object
        Filename of, or open handle on, the gzipped TAR archive that
        holds the ILSVRC2010 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        See :func:`read_metadata_mat_file` for details.
    cost_matrix : ndarray, 2-dimensional, uint8
        See :func:`read_metadata_mat_file` for details.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        Validation set labels (ILSVRC2010 IDs), which ship inside the
        development kit archive rather than with the images.

    """
    with tar_open(f) as devkit:
        # The metadata table holds the class hierarchy, textual
        # descriptions, etc.; parsing is delegated to the helper.
        synsets, cost_matrix = read_metadata_mat_file(
            devkit.extractfile(DEVKIT_META_PATH))

        # The validation labels are a plain text member of the same
        # archive; parse them straight into an int16 array.
        groundtruth_member = devkit.extractfile(
            DEVKIT_VALID_GROUNDTRUTH_PATH)
        labels = numpy.loadtxt(groundtruth_member, dtype=numpy.int16)
    return synsets, cost_matrix, labels
예제 #2
0
파일: ilsvrc2010.py 프로젝트: Afrik/fuel
def read_devkit(f):
    """Load metadata and validation labels from the devkit archive.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2010 development kit.

    Returns
    -------
    synsets : ndarray, 1-dimensional, compound dtype
        See :func:`read_metadata_mat_file` for details.
    cost_matrix : ndarray, 2-dimensional, uint8
        See :func:`read_metadata_mat_file` for details.
    raw_valid_groundtruth : ndarray, 1-dimensional, int16
        The ILSVRC2010 validation set labels that are distributed
        alongside the development kit code.

    """
    with tar_open(f) as archive:
        # First member of interest: the metadata table (class
        # hierarchy, textual descriptions, and so on).
        meta_handle = archive.extractfile(DEVKIT_META_PATH)
        synsets, cost_matrix = read_metadata_mat_file(meta_handle)

        # Second member of interest: the raw validation groundtruth,
        # expressed as ILSVRC2010 IDs, read as int16.
        valid_handle = archive.extractfile(DEVKIT_VALID_GROUNDTRUTH_PATH)
        raw_valid_groundtruth = numpy.loadtxt(valid_handle,
                                              dtype=numpy.int16)
    return synsets, cost_matrix, raw_valid_groundtruth
예제 #3
0
def prepare_metadata(devkit_archive):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.

    Returns
    -------
    n_train : int
        The number of examples in the training set, computed as the sum
        of the ``num_train_images`` field over all synsets.
    valid_groundtruth : list of int
        The validation set groundtruth in terms of 0-based class
        indices, in the order given by the development kit.
    n_test : int
        The number of examples in the test set, counted as the number
        of members in the ``TEST_IMAGES_TAR`` archive.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    Raises
    ------
    KeyError
        If a validation label does not match any synset's
        ``ILSVRC2012_ID``.

    """
    # Read what's necessary from the development kit.
    synsets, raw_valid_groundtruth = read_devkit(devkit_archive)

    # Mapping to take WordNet IDs to our internal 0-999 encoding.
    # NOTE(review): the hard-coded 1000 means zip() silently truncates
    # if there are more synsets than that -- confirm read_devkit always
    # yields exactly 1000 entries.
    wnid_map = dict(
        zip((s.decode('utf8') for s in synsets['WNID']), xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(
        zip(synsets['ILSVRC2012_ID'], xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [
        ilsvrc_id_to_zero_based[id_] for id_ in raw_valid_groundtruth
    ]

    # Get number of test examples from the test archive. Every archive
    # member is counted.
    with tar_open(TEST_IMAGES_TAR) as f:
        n_test = sum(1 for _ in f)

    # Ascertain the number of filenames to prepare appropriate sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    n_valid = len(valid_groundtruth)
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(n_valid))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + n_valid + n_test
    # The total covers all three splits, so the label must say so.
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
예제 #4
0
파일: ilsvrc2010.py 프로젝트: Afrik/fuel
def train_set_producer(socket, train_archive, patch_archive, wnid_map):
    """Stream every training image (or its patch) over a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket on which to send loaded images.
    train_archive :  str or file-like object
        Filename or file handle for the TAR archive of training images.
        Each member is itself opened as a TAR archive.
    patch_archive :  str or file-like object
        Filename or file handle for the TAR archive of patch images.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.
        Used to decode the filenames of the inner TAR files.

    Raises
    ------
    ValueError
        If the number of patched images sent differs from the number
        of patch images extracted for the training set.

    """
    patch_images = extract_patch_images(patch_archive, 'train')
    patched_count = 0
    with tar_open(train_archive) as outer:
        for class_info in outer:
            class_tar_name = class_info.name
            with tar_open(outer.extractfile(class_tar_name)) as class_tar:
                # The inner archive's name up to the first dot is the
                # WordNet ID of the class it contains.
                class_index = wnid_map[class_tar_name.split('.')[0]]
                members = sorted(m.name for m in class_tar if m.isfile())
                # Both streams are lazy, so images are loaded one at a
                # time as the loop below advances.
                basenames = (os.path.split(m)[-1] for m in members)
                loaded = (
                    load_from_tar_or_patch(class_tar, m, patch_images)
                    for m in members
                )
                for basename, (payload, was_patched) in equizip(basenames,
                                                                loaded):
                    patched_count += 1 if was_patched else 0
                    # Metadata goes first, flagged so the image bytes
                    # follow as part of the same message.
                    socket.send_pyobj((basename, class_index), zmq.SNDMORE)
                    socket.send(payload)
    if patched_count != len(patch_images):
        raise ValueError('not all patch images were used')
예제 #5
0
def train_set_producer(socket, train_archive, patch_archive, wnid_map):
    """Read the training TAR-of-TARs and push its images to a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket on which to send loaded images.
    train_archive :  str or file-like object
        Filename or file handle for the TAR archive of training images.
    patch_archive :  str or file-like object
        Filename or file handle for the TAR archive of patch images.
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.
        Used to decode the filenames of the inner TAR files.

    Raises
    ------
    ValueError
        If, after all images were sent, the count of patched images
        does not equal the number of available patch images.

    """
    replacements = extract_patch_images(patch_archive, 'train')
    n_replaced = 0
    with tar_open(train_archive) as archive:
        for member in archive:
            with tar_open(archive.extractfile(member.name)) as per_class:
                # Inner TAR files are named after the WordNet ID.
                wnid, = member.name.split('.')[:1]
                label = wnid_map[wnid]
                paths = sorted(entry.name for entry in per_class
                               if entry.isfile())
                names = (os.path.split(path)[-1] for path in paths)
                contents = (load_from_tar_or_patch(per_class, path,
                                                   replacements)
                            for path in paths)
                pairs = equizip(names, contents)
                for name, (data, from_patch) in pairs:
                    if from_patch:
                        n_replaced += 1
                    # Send (filename, label) then the raw image bytes,
                    # tied together via the SNDMORE flag.
                    socket.send_pyobj((name, label), zmq.SNDMORE)
                    socket.send(data)
    if n_replaced != len(replacements):
        raise ValueError('not all patch images were used')
예제 #6
0
파일: ilsvrc2012.py 프로젝트: Scyfer/fuel
def prepare_metadata(devkit_archive):
    """Extract dataset metadata required for HDF5 file setup.

    Parameters
    ----------
    devkit_archive : str or file-like object
        The filename or file-handle for the gzipped TAR archive
        containing the ILSVRC2012 development kit.

    Returns
    -------
    n_train : int
        The number of examples in the training set (sum of the
        synsets' ``num_train_images`` field).
    valid_groundtruth : list of int
        The validation set groundtruth in terms of 0-based class
        indices.
    n_test : int
        The number of examples in the test set (number of members in
        the ``TEST_IMAGES_TAR`` archive).
    wnid_map : dict
        A dictionary that maps WordNet IDs to 0-based class indices.

    Raises
    ------
    KeyError
        If a raw validation label is not the ``ILSVRC2012_ID`` of any
        synset.

    """
    # Read what's necessary from the development kit.
    synsets, raw_valid_groundtruth = read_devkit(devkit_archive)

    # Mapping to take WordNet IDs to our internal 0-999 encoding.
    # NOTE(review): 1000 is hard-coded here while the ID map below uses
    # len(synsets); zip() truncates to the shorter input -- verify the
    # two always agree.
    wnid_map = dict(zip((s.decode('utf8') for s in synsets['WNID']),
                        xrange(1000)))

    # Map the 'ILSVRC2012 ID' to our zero-based ID.
    ilsvrc_id_to_zero_based = dict(zip(synsets['ILSVRC2012_ID'],
                                       xrange(len(synsets))))

    # Map the validation set groundtruth to 0-999 labels.
    valid_groundtruth = [ilsvrc_id_to_zero_based[id_]
                         for id_ in raw_valid_groundtruth]

    # Get number of test examples from the test archive by counting
    # its members.
    with tar_open(TEST_IMAGES_TAR) as f:
        n_test = sum(1 for _ in f)

    # Ascertain the number of filenames to prepare appropriate sized
    # arrays.
    n_train = int(synsets['num_train_images'].sum())
    log.info('Training set: {} images'.format(n_train))
    log.info('Validation set: {} images'.format(len(valid_groundtruth)))
    log.info('Test set: {} images'.format(n_test))
    n_total = n_train + len(valid_groundtruth) + n_test
    # n_total includes the test split, so the label names all three.
    log.info('Total (train/valid/test): {} images'.format(n_total))
    return n_train, valid_groundtruth, n_test, wnid_map
예제 #7
0
파일: ilsvrc2010.py 프로젝트: Afrik/fuel
def extract_patch_images(f, which_set):
    """Extracts a dict of the "patch images" for ILSVRC2010.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle to the patch images TAR file.
    which_set : str
        Which set of images to extract. One of 'train', 'valid', 'test'.

    Returns
    -------
    dict
        A dictionary contains a mapping of filenames (without path) to a
        bytes object containing the replacement image.

    Raises
    ------
    ValueError
        If `which_set` is not one of 'train', 'valid' or 'test'.

    Notes
    -----
    Certain images in the distributed archives are blank, or display
    an "image not available" banner. A separate TAR file of
    "patch images" is distributed with the corrected versions of
    these. It is this archive that this function is intended to read.

    Only members whose name ends in '.JPEG' and whose immediate parent
    directory matches the requested set ('valid' is looked up as 'val')
    are returned. Members without a parent directory, and members that
    cannot be read as a file, are skipped.

    """
    if which_set not in ('train', 'valid', 'test'):
        raise ValueError('which_set must be one of train, valid, or test')
    # Inside the archive the validation directory is named 'val'.
    which_set = 'val' if which_set == 'valid' else which_set
    patch_images = {}
    with tar_open(f) as tar:
        for info_obj in tar:
            if not info_obj.name.endswith('.JPEG'):
                continue
            # Pretty sure that '/' is used for tarfile regardless of
            # os.path.sep, but I officially don't care about Windows.
            tokens = info_obj.name.split('/')
            # A JPEG sitting at the archive root has no parent directory
            # and therefore belongs to no set; indexing tokens[-2] would
            # raise IndexError for it.
            if len(tokens) < 2 or tokens[-2] != which_set:
                continue
            # extractfile() yields None for members that are not files
            # (e.g. a directory whose name ends in '.JPEG').
            member_file = tar.extractfile(info_obj.name)
            if member_file is None:
                continue
            patch_images[tokens[-1]] = member_file.read()
    return patch_images
예제 #8
0
def extract_patch_images(f, which_set):
    """Extracts a dict of the "patch images" for ILSVRC2010.

    Parameters
    ----------
    f : str or file-like object
        The filename or file-handle to the patch images TAR file.
    which_set : str
        Which set of images to extract. One of 'train', 'valid', 'test'.

    Returns
    -------
    dict
        A dictionary contains a mapping of filenames (without path) to a
        bytes object containing the replacement image.

    Raises
    ------
    ValueError
        If `which_set` is not one of 'train', 'valid' or 'test'.

    Notes
    -----
    Certain images in the distributed archives are blank, or display
    an "image not available" banner. A separate TAR file of
    "patch images" is distributed with the corrected versions of
    these. It is this archive that this function is intended to read.

    A member is selected when its name ends in '.JPEG' and the path
    component directly above the filename equals the requested set
    (with 'valid' spelled 'val'). Members lacking such a component, or
    for which no file object can be obtained, are ignored.

    """
    if which_set not in ('train', 'valid', 'test'):
        raise ValueError('which_set must be one of train, valid, or test')
    # The archive uses 'val' as the directory name for validation.
    which_set = 'val' if which_set == 'valid' else which_set
    patch_images = {}
    with tar_open(f) as tar:
        for info_obj in tar:
            if not info_obj.name.endswith('.JPEG'):
                continue
            # Pretty sure that '/' is used for tarfile regardless of
            # os.path.sep, but I officially don't care about Windows.
            tokens = info_obj.name.split('/')
            if len(tokens) < 2:
                # No parent directory: tokens[-2] would raise
                # IndexError, and the image belongs to no set anyway.
                continue
            file_which_set = tokens[-2]
            if file_which_set != which_set:
                continue
            handle = tar.extractfile(info_obj.name)
            if handle is None:
                # Non-file member (e.g. directory named '*.JPEG');
                # there is nothing to read.
                continue
            filename = tokens[-1]
            patch_images[filename] = handle.read()
    return patch_images
예제 #9
0
파일: ilsvrc2010.py 프로젝트: Afrik/fuel
def other_set_producer(socket, which_set, image_archive, patch_archive,
                       groundtruth):
    """Push image files read from the valid/test set TAR to a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket on which to send images.
    which_set : str
        Which set of images is being processed. One of 'train', 'valid',
        'test'.  Used for extracting the appropriate images from the patch
        archive.
    image_archive : str or file-like object
        The filename or file-handle for the TAR archive containing images.
    patch_archive : str or file-like object
        Filename or file handle for the TAR archive of patch images.
    groundtruth : iterable
        Iterable container containing scalar 0-based class index for each
        image, sorted by filename.

    Raises
    ------
    ValueError
        If the number of images that were replaced by a patch differs
        from the number of patch images available for `which_set`.

    """
    patch_images = extract_patch_images(patch_archive, which_set)
    num_patched = 0
    with tar_open(image_archive) as tar:
        # Sort by name so images line up with `groundtruth`.
        filenames = sorted(info.name for info in tar if info.isfile())
        # Lazy: each image is only loaded when the loop reaches it.
        images = (load_from_tar_or_patch(tar, filename, patch_images)
                  for filename in filenames)
        pathless_filenames = (os.path.split(fn)[-1] for fn in filenames)
        # NOTE(review): equizip presumably rejects streams of unequal
        # length (e.g. too few labels) -- confirm against its definition.
        image_iterator = equizip(images, pathless_filenames, groundtruth)
        for (image_data, patched), filename, class_index in image_iterator:
            if patched:
                num_patched += 1
            # Metadata first, flagged SNDMORE so the image bytes follow
            # in the same message.
            socket.send_pyobj((filename, class_index), zmq.SNDMORE)
            socket.send(image_data, copy=False)
    if num_patched != len(patch_images):
        # Same error as train_set_producer; ValueError is still an
        # Exception, so existing handlers keep working.
        raise ValueError('not all patch images were used')
예제 #10
0
def other_set_producer(socket, which_set, image_archive, patch_archive,
                       groundtruth):
    """Push image files read from the valid/test set TAR to a socket.

    Parameters
    ----------
    socket : :class:`zmq.Socket`
        PUSH socket on which to send images.
    which_set : str
        Which set of images is being processed. One of 'train', 'valid',
        'test'.  Used for extracting the appropriate images from the patch
        archive.
    image_archive : str or file-like object
        The filename or file-handle for the TAR archive containing images.
    patch_archive : str or file-like object
        Filename or file handle for the TAR archive of patch images.
    groundtruth : iterable
        Iterable container containing scalar 0-based class index for each
        image, sorted by filename.

    Raises
    ------
    ValueError
        If, once every image has been sent, the count of patched images
        does not match the number of patch images extracted for
        `which_set`.

    """
    patch_images = extract_patch_images(patch_archive, which_set)
    num_patched = 0
    with tar_open(image_archive) as tar:
        # File members only, in name order, to pair with `groundtruth`.
        filenames = sorted(info.name for info in tar if info.isfile())
        images = (load_from_tar_or_patch(tar, filename, patch_images)
                  for filename in filenames)
        pathless_filenames = (os.path.split(fn)[-1] for fn in filenames)
        image_iterator = equizip(images, pathless_filenames, groundtruth)
        for (image_data, patched), filename, class_index in image_iterator:
            if patched:
                num_patched += 1
            # (filename, label) goes out first; zmq.SNDMORE marks that
            # the raw image data follows in the same message.
            socket.send_pyobj((filename, class_index), zmq.SNDMORE)
            socket.send(image_data, copy=False)
    if num_patched != len(patch_images):
        # Give a specific, descriptive error (matching the training
        # producer) instead of a bare, message-less Exception.
        raise ValueError('not all patch images were used')