Code example #1
def _create_validation_cache(archive, output, names, ground_truth, args):
    images0 = sorted(names)

    # Thinning
    images = []
    for i, image in enumerate(images0):
        if i % args.thinning == 0:
            images.append(image)

    def _load_func(index):
        y, name = ground_truth[index], images[index]
        im = imread(archive.extractfile(name), num_channels=3)
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        return x, np.array([y]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    # Parse the command-line string flag into a real boolean.
    shuffle = args.shuffle == 'True'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed boolean, not the original string, so that
    # shuffle='False' is not treated as truthy.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
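The thinning loop above keeps every `args.thinning`-th entry of the sorted list. As a minimal equivalent sketch (no new names assumed), the same selection can be written with a stride slice:

# Equivalent to the thinning loop: keep every args.thinning-th image.
images = images0[::args.thinning]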
Code example #2
def _create_validation_cache(archive, output, names, ground_truth, args):
    # ILSVRC2012_devkit_t12/readme.txt
    #     The ground truth of the validation images is in
    #     data/ILSVRC2012_validation_ground_truth.txt, where each line contains
    #     one ILSVRC2012_ID for one image, in the ascending alphabetical order
    #     of the image file names.
    images0 = sorted(names)

    # Thinning
    images = []
    for i, image in enumerate(images0):
        if i % args.thinning == 0:
            images.append(image)

    def _load_func(index):
        y, name = ground_truth[index], images[index]
        im = imread(archive.extractfile(name))
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        # ILSVRC2012_IDs are 1-based; shift to a 0-based class label.
        return x, np.array([y - 1]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    # Parse the command-line string flag into a real boolean.
    shuffle = args.shuffle == 'True'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed boolean, not the original string.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
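The comment at the top of this example describes where the validation labels live. As a minimal sketch of reading that file (assuming the devkit layout quoted above; `devkit_path` is a hypothetical argument):

import os

def load_validation_ground_truth(devkit_path):
    # One 1-based ILSVRC2012_ID per line, in ascending alphabetical
    # order of the validation image file names (see the readme above).
    gt_file = os.path.join(devkit_path, 'data',
                           'ILSVRC2012_validation_ground_truth.txt')
    with open(gt_file) as f:
        return [int(line) for line in f if line.strip()]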
Code example #3
File: conv_dataset.py  Project: hixcod/nnabla
def _convert(args, source):
    _, ext = os.path.splitext(args.destination)
    if ext.lower() == '.cache':
        with DataSourceWithFileCache(source,
                                     cache_dir=args.destination,
                                     shuffle=args.shuffle) as ds:
            print('Number of Data: {}'.format(ds.size))
            print('Shuffle:        {}'.format(args.shuffle))
            print('Normalize:      {}'.format(args.normalize))
            pbar = tqdm(total=ds.size)
            for i in range(ds.size):
                # Touch every item so the file cache is actually written.
                ds._get_data(i)
                pbar.update(1)
    else:
        print('Command `conv_dataset` only supports CACHE as destination.')
Code example #4
def _create_train_cache(archive, output, names, args):
    # Read label and wordnet_id
    wid2ind = np.loadtxt(fname=LABEL_WORDNETID, dtype=object, delimiter=",")

    def _get_label(wordnet_id):
        for item in wid2ind:
            if item[1] == wordnet_id:
                return item[0]

    images0 = []
    print("Count image in TAR")
    pbar = tqdm.tqdm(total=len(names), unit='%')
    for name in names:
        category = os.path.splitext(name)[0]
        marchive = tarfile.open(fileobj=archive.extractfile(name))
        for mname in marchive.getnames():
            if re.match(r'{}_[0-9]+\.JPEG'.format(category), mname):
                images0.append((_get_label(name[:9]), name, marchive, mname))
            else:
                print('Invalid file {} included in tar file'.format(mname))
                exit(-1)
        pbar.update(1)
    pbar.close()

    # Thinning
    images = []
    for i, image in enumerate(images0):
        if i % args.thinning == 0:
            images.append(image)

    def _load_func(index):
        y, name, marchive, mname = images[index]
        im = imread(marchive.extractfile(mname), num_channels=3)
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        return x, np.array([y]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    # Parse the command-line string flag into a real boolean (default: shuffle).
    shuffle = args.shuffle != 'False'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed boolean, not the original string.
    DataSourceWithFileCache(
        source, cache_dir=output, shuffle=shuffle)
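`_get_label` above does a linear scan over the label table for every image. As a hedged alternative sketch, assuming the same `wid2ind` array of `(label, wordnet_id)` rows, a dictionary built once gives constant-time lookups:

# Hypothetical O(1) replacement for _get_label, built from the same table.
wid2label = {wordnet_id: label for label, wordnet_id in wid2ind}
# usage: wid2label[name[:9]] instead of _get_label(name[:9])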
Code example #5
def _create_train_cache(archive, output, names, synsets_id, args):
    images0 = []
    print("Count image in TAR")
    pbar = tqdm.tqdm(total=len(names), unit='%')
    for name in names:
        category = os.path.splitext(name)[0]
        marchive = tarfile.open(fileobj=archive.extractfile(name))
        for mname in marchive.getnames():
            if re.match(r'{}_[0-9]+\.JPEG'.format(category), mname):
                images0.append((synsets_id[category], name, marchive, mname))
            else:
                print('Invalid file {} included in tar file'.format(mname))
                exit(-1)
        pbar.update(1)
    pbar.close()

    # Thinning
    images = []
    for i, image in enumerate(images0):
        if i % args.thinning == 0:
            images.append(image)

    def _load_func(index):
        y, name, marchive, mname = images[index]
        # Note: scipy.misc.imread was removed in SciPy 1.2; newer code
        # typically uses imageio.imread instead.
        im = scipy.misc.imread(marchive.extractfile(mname), mode='RGB')
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        return x, np.array([y - 1]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    # Parse the command-line string flag into a real boolean (default: shuffle).
    shuffle = args.shuffle != 'False'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed boolean, not the original string.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
Code example #6
nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size',
                  '{}'.format(args.cache_size))
nnabla_config.set('DATA_ITERATOR', 'cache_file_format', '.h5')

HERE = os.path.dirname(__file__)
nnabla_examples_root = os.path.join(HERE, '../../../../nnabla-examples')
mnist_examples_root = os.path.realpath(
    os.path.join(nnabla_examples_root, 'mnist-collection'))
sys.path.append(mnist_examples_root)

from mnist_data import MnistDataSource
mnist_training_cache = os.path.join(args.output, 'mnist_training.cache')
if not os.path.exists(mnist_training_cache):
    os.makedirs(mnist_training_cache)
DataSourceWithFileCache(data_source=MnistDataSource(train=True,
                                                    shuffle=False,
                                                    rng=None),
                        cache_dir=mnist_training_cache,
                        shuffle=False,
                        rng=None)
mnist_test_cache = os.path.join(args.output, 'mnist_test.cache')
if not os.path.exists(mnist_test_cache):
    os.makedirs(mnist_test_cache)
DataSourceWithFileCache(data_source=MnistDataSource(train=False,
                                                    shuffle=False,
                                                    rng=None),
                        cache_dir=mnist_test_cache,
                        shuffle=False,
                        rng=None)
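Once the cache directories exist, they can be read back without the original data source. A minimal sketch, assuming nnabla's `data_iterator_cache` helper and the `mnist_training_cache` directory created above:

from nnabla.utils.data_iterator import data_iterator_cache

# Iterate mini-batches straight from the cache files on disk.
with data_iterator_cache(mnist_training_cache, batch_size=64,
                         shuffle=True) as di:
    x, t = di.next()  # one mini-batch of images and labels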
Code example #7
def data_iterator(data_source,
                  batch_size,
                  rng=None,
                  with_memory_cache=True,
                  with_file_cache=False,
                  cache_dir=None,
                  epoch_begin_callbacks=[],
                  epoch_end_callbacks=[]):
    '''data_iterator
    Helper method to use :py:class:`DataSource <nnabla.utils.data_source.DataSource>`.

    You can use :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>` with your own :py:class:`DataSource <nnabla.utils.data_source.DataSource>`
    for easy implementation of data sources.

    For example,

    .. code-block:: python

        ds = YourOwnImplementationOfDataSource()
        batch = data_iterator(ds, batch_size)


    Args:
        data_source (:py:class:`DataSource <nnabla.utils.data_source.DataSource>`):
             Instance of DataSource class which provides data.
        batch_size (int): Batch size.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        with_memory_cache (bool):
            If ``True``, use :py:class:`.data_source.DataSourceWithMemoryCache`
            to wrap ``data_source``. It is a good idea to set this to ``True``
            unless ``data_source`` already provides on-memory data.
            Default value is True.
        with_file_cache (bool):
            If ``True``, use :py:class:`.data_source.DataSourceWithFileCache`
            to wrap ``data_source``.
            If ``data_source`` is slow, enabling this option is a good idea.
            Default value is False.
        cache_dir (str):
            Location of the file cache.
            If this value is None, :py:class:`.data_source.DataSourceWithFileCache`
            creates file caches implicitly in a temporary directory and erases
            them all when the data_iterator is finished.
            Otherwise, :py:class:`.data_source.DataSourceWithFileCache` keeps the created cache.
            Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called
            at the beginning of an epoch.
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called
            at the end of an epoch.

    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator.
    '''
    if with_file_cache:
        ds = DataSourceWithFileCache(data_source=data_source,
                                     cache_dir=cache_dir,
                                     shuffle=data_source.shuffle,
                                     rng=rng)
        if with_memory_cache:
            ds = DataSourceWithMemoryCache(ds,
                                           shuffle=ds.shuffle,
                                           rng=rng)
        return DataIterator(ds,
                            batch_size,
                            epoch_begin_callbacks=epoch_begin_callbacks,
                            epoch_end_callbacks=epoch_end_callbacks)
    else:
        if with_memory_cache:
            data_source = DataSourceWithMemoryCache(data_source,
                                                    shuffle=data_source.shuffle,
                                                    rng=rng)
        return DataIterator(data_source, batch_size,
                            epoch_begin_callbacks=epoch_begin_callbacks,
                            epoch_end_callbacks=epoch_end_callbacks)
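For instance, a slow on-disk data source could be wrapped with both caches. A minimal sketch assuming the parameters documented above (`SlowDataSource` is a hypothetical user-defined class):

# File cache persists under cache_dir; memory cache sits on top of it.
di = data_iterator(SlowDataSource(),
                   batch_size=32,
                   with_memory_cache=True,
                   with_file_cache=True,
                   cache_dir='/tmp/my_cache')
x, t = di.next()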
Code example #8
    def slice(self, rng, num_of_slices=None, slice_pos=None,
              slice_start=None, slice_end=None,
              cache_dir=None):
        '''
        Slices the data iterator so that the newly generated data iterator has access to a limited portion of the original data.

        Args:
            rng (numpy.random.RandomState): Random number generator used for shuffling the sliced data.
            num_of_slices(int): Total number of slices to be made. Must be used together with `slice_pos`.
            slice_pos(int): Position of the slice to be assigned to the new data iterator. Must be used together with `num_of_slices`.
            slice_start(int): Starting position of the range to be sliced into the new data iterator. Must be used together with `slice_end`.
            slice_end(int): End position of the range to be sliced into the new data iterator. Must be used together with `slice_start`.
            cache_dir(str): Directory to save cache files.

        Example:

        .. code-block:: python

            from nnabla.utils.data_iterator import data_iterator_simple
            import numpy as np

            def load_func1(index):
                d = np.ones((2, 2)) * index
                return d

            di = data_iterator_simple(load_func1, 1000, batch_size=3)

            di_s1 = di.slice(None, num_of_slices=10, slice_pos=0)
            di_s2 = di.slice(None, num_of_slices=10, slice_pos=1)

            di_s3 = di.slice(None, slice_start=100, slice_end=200)
            di_s4 = di.slice(None, slice_start=300, slice_end=400)

        '''

        if num_of_slices is not None and slice_pos is not None and slice_start is None and slice_end is None:
            size = self._size // num_of_slices
            amount = self._size % num_of_slices
            slice_start = slice_pos * size
            # Distribute the remainder: the first `amount` slices each
            # receive one extra element, so no index is skipped.
            if slice_pos < amount:
                slice_start += slice_pos
                slice_end = slice_start + size + 1
            else:
                slice_start += amount
                slice_end = slice_start + size
            if slice_end > self._size:
                slice_start -= (slice_end - self._size)
                slice_end = self._size

        elif num_of_slices is None and slice_pos is None and slice_start is not None and slice_end is not None:
            pass
        else:
            logger.critical(
                'You must specify either a position (num_of_slices and slice_pos) or a range (slice_start and slice_end).')
            return None

        if cache_dir is None:
            # Walk down the chain of wrapped data sources and reuse an
            # existing cache directory if one is found.
            ds = self._data_source
            while '_data_source' in dir(ds):
                if '_cache_dir' in dir(ds):
                    cache_dir = ds._cache_dir
                ds = ds._data_source

        if cache_dir is None:
            return DataIterator(
                DataSourceWithMemoryCache(
                    SlicedDataSource(
                        self._data_source,
                        self._data_source.shuffle,
                        slice_start=slice_start,
                        slice_end=slice_end),
                    shuffle=self._shuffle,
                    rng=rng),
                self._batch_size)
        else:
            return DataIterator(
                DataSourceWithMemoryCache(
                    DataSourceWithFileCache(
                        SlicedDataSource(
                            self._data_source,
                            self._data_source.shuffle,
                            slice_start=slice_start,
                            slice_end=slice_end),
                        cache_dir=cache_dir,
                        cache_file_name_prefix='cache_sliced_{:08d}_{:08d}'.format(
                            slice_start,
                            slice_end),
                        shuffle=self._shuffle,
                        rng=rng),
                    shuffle=self._shuffle,
                    rng=rng),
                self._batch_size)
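A typical use of `cache_dir` here is distributed training, where each worker takes one slice of the data and persists it. A minimal sketch assuming the `di` iterator from the docstring example (`num_workers` and `worker_id` are hypothetical):

# Each worker caches only its own slice of the data.
di_worker = di.slice(None, num_of_slices=num_workers,
                     slice_pos=worker_id,
                     cache_dir='/tmp/sliced_cache')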