Ejemplo n.º 1
0
def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that ``SimpleDataSource`` yields every sample exactly once per epoch.

    Every image listed in the CSV fixture is loaded into memory and served
    through ``SimpleDataSource``. For each yielded sample, ``data[0][0][0]``
    must equal its label. Across the epoch, the labels must be
    ``0 .. size-1``:

    * in order when ``shuffle`` is False;
    * as a permutation that differs from the identity when ``shuffle`` is True.
    """
    data_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(data_dir, values[0])
            # Rows whose first column is not an existing file (presumably the
            # CSV header) are skipped. Use ``isfile`` rather than ``exists``:
            # an empty first column (e.g. a blank line) joins to ``data_dir``
            # itself, which exists but is a directory, so ``open`` would fail.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        assert data[0][0][0] == label
        order.append(label)

    expected = list(range(size))
    if shuffle:
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
Ejemplo n.º 2
0
def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that ``SimpleDataSource`` yields every sample exactly once per epoch.

    Every image listed in the CSV fixture is loaded into memory and served
    through ``SimpleDataSource``. For each yielded sample, ``data[0][0][0]``
    must equal its label. Across the epoch, the labels must be
    ``0 .. size-1``:

    * in order when ``shuffle`` is False;
    * as a permutation that differs from the identity when ``shuffle`` is True.
    """
    data_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(data_dir, values[0])
            # Rows whose first column is not an existing file (presumably the
            # CSV header) are skipped. Use ``isfile`` rather than ``exists``:
            # an empty first column (e.g. a blank line) joins to ``data_dir``
            # itself, which exists but is a directory, so ``open`` would fail.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        assert data[0][0][0] == label
        order.append(label)

    expected = list(range(size))
    if shuffle:
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
Ejemplo n.º 3
0
def _create_validation_cache(archive, output, names, ground_truth, args):
    """Build a file cache of the (thinned) validation set.

    Args:
        archive: Opened tar archive of validation images; members are read
            with ``archive.extractfile``.
        output (str): Directory handed to ``DataSourceWithFileCache`` as
            ``cache_dir``.
        names: Image member names inside ``archive``.
        ground_truth: One label per image. It is assumed to be aligned with
            ``sorted(names)`` (the full, un-thinned list); TODO confirm
            against the caller.
        args: Parsed options. Reads ``thinning``, ``width``, ``height``,
            ``mode`` and ``shuffle``; the string ``'True'`` enables
            shuffling.
    """
    images0 = sorted(names)

    # Thinning: keep every ``args.thinning``-th image together with its own
    # label. ``ground_truth`` lines up with the full ``images0`` list, so it
    # must be thinned in lockstep. Indexing it by the position in the thinned
    # list would pair images with the wrong labels whenever thinning > 1.
    samples = [(ground_truth[i], image)
               for i, image in enumerate(images0)
               if i % args.thinning == 0]

    def _load_func(index):
        y, name = samples[index]
        im = imread(archive.extractfile(name), num_channels=3)
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        return x, np.array([y]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(samples)))
    shuffle = args.shuffle == 'True'
    source = SimpleDataSource(_load_func, len(samples), shuffle, rng=None)
    # Pass the parsed bool, not the raw option string: any non-empty string
    # (including 'False') is truthy.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
Ejemplo n.º 4
0
def _create_validation_cache(archive, output, names, ground_truth, args):
    """Build a file cache of the (thinned) ILSVRC2012 validation set.

    Args:
        archive: Opened tar archive of validation images; members are read
            with ``archive.extractfile``.
        output (str): Directory handed to ``DataSourceWithFileCache`` as
            ``cache_dir``.
        names: Image member names inside ``archive``.
        ground_truth: ILSVRC2012_IDs, one per image, in ascending
            alphabetical order of the image file names (see the devkit
            note below).
        args: Parsed options. Reads ``thinning``, ``width``, ``height``,
            ``mode`` and ``shuffle``; the string ``'True'`` enables
            shuffling.
    """
    # ILSVRC2012_devkit_t12/readme.txt
    #     The ground truth of the validation images is in
    #     data/ILSVRC2012_validation_ground_truth.txt, where each line contains
    #     one ILSVRC2012_ID for one image, in the ascending alphabetical order
    #     of the image file names.
    images0 = sorted(names)

    # Thinning: keep every ``args.thinning``-th image together with its own
    # label. ``ground_truth`` lines up with the full ``images0`` list, so it
    # must be thinned in lockstep. Indexing it by the position in the thinned
    # list would pair images with the wrong labels whenever thinning > 1.
    samples = [(ground_truth[i], image)
               for i, image in enumerate(images0)
               if i % args.thinning == 0]

    def _load_func(index):
        y, name = samples[index]
        im = imread(archive.extractfile(name))
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        # Shift the ID down by one; presumably IDs are 1-based and the
        # training side expects 0-based class indices (TODO confirm).
        return x, np.array([y - 1]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(samples)))
    shuffle = args.shuffle == 'True'
    source = SimpleDataSource(_load_func, len(samples), shuffle, rng=None)
    # Pass the parsed bool, not the raw option string: any non-empty string
    # (including 'False') is truthy.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
Ejemplo n.º 5
0
def _create_train_cache(archive, output, names, args):
    """Build a file cache of the (thinned) training set.

    ``archive`` contains one nested tar per category. Every member of each
    nested tar must be named ``<category>_<digits>.JPEG``. Otherwise the
    offending name is printed and the process exits with status -1.

    Args:
        archive: Opened outer tar archive; nested tars are read with
            ``archive.extractfile``.
        output (str): Directory handed to ``DataSourceWithFileCache`` as
            ``cache_dir``.
        names: Member names of the nested per-category tars.
        args: Parsed options. Reads ``thinning``, ``width``, ``height``,
            ``mode`` and ``shuffle``; shuffling is on unless it is the
            string ``'False'``.
    """
    import sys

    # Read label and wordnet_id
    wid2ind = np.loadtxt(fname=LABEL_WORDNETID, dtype=object, delimiter=",")
    # Build the wordnet id (column 1) -> label (column 0) map once, instead of
    # scanning the whole table for every image. ``setdefault`` keeps the
    # first occurrence of each id, matching the previous linear search.
    # Unknown ids map to None, as before.
    wid2label = {}
    for item in wid2ind:
        wid2label.setdefault(item[1], item[0])

    images0 = []
    print("Count image in TAR")
    pbar = tqdm.tqdm(total=len(names), unit='%')
    for name in names:
        category = os.path.splitext(name)[0]
        # The label depends only on the nested tar's name, so look it up once.
        label = wid2label.get(name[:9])
        pattern = re.compile(r'{}_[0-9]+\.JPEG'.format(re.escape(category)))
        marchive = tarfile.open(fileobj=archive.extractfile(name))
        for mname in marchive.getnames():
            if pattern.match(mname):
                # Keep the nested tar open: images are read lazily in
                # _load_func.
                images0.append((label, name, marchive, mname))
            else:
                print('Invalid file {} includes in tar file'.format(mname))
                sys.exit(-1)
        pbar.update(1)
    pbar.close()

    # Thinning: keep every ``args.thinning``-th image (label travels with it).
    images = [image for i, image in enumerate(images0)
              if i % args.thinning == 0]

    def _load_func(index):
        y, name, marchive, mname = images[index]
        im = imread(marchive.extractfile(mname), num_channels=3)
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        return x, np.array([y]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    shuffle = args.shuffle != 'False'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed bool, not the raw option string: any non-empty string
    # (including 'False') is truthy.
    DataSourceWithFileCache(
        source, cache_dir=output, shuffle=shuffle)
Ejemplo n.º 6
0
def _create_train_cache(archive, output, names, synsets_id, args):
    """Build a file cache of the (thinned) training set.

    ``archive`` contains one nested tar per category. Every member of each
    nested tar must be named ``<category>_<digits>.JPEG``. Otherwise the
    offending name is printed and the process exits with status -1.

    Args:
        archive: Opened outer tar archive; nested tars are read with
            ``archive.extractfile``.
        output (str): Directory handed to ``DataSourceWithFileCache`` as
            ``cache_dir``.
        names: Member names of the nested per-category tars.
        synsets_id: Mapping from category (nested tar name without
            extension) to its numeric id.
        args: Parsed options. Reads ``thinning``, ``width``, ``height``,
            ``mode`` and ``shuffle``; shuffling is on unless it is the
            string ``'False'``.
    """
    import sys

    images0 = []
    print("Count image in TAR")
    pbar = tqdm.tqdm(total=len(names), unit='%')
    for name in names:
        category = os.path.splitext(name)[0]
        marchive = tarfile.open(fileobj=archive.extractfile(name))
        for mname in marchive.getnames():
            if re.match(r'{}_[0-9]+\.JPEG'.format(category), mname):
                # Keep the nested tar open: images are read lazily in
                # _load_func.
                images0.append((synsets_id[category], name, marchive, mname))
            else:
                print('Invalid file {} includes in tar file'.format(mname))
                sys.exit(-1)
        pbar.update(1)
    pbar.close()

    # Thinning: keep every ``args.thinning``-th image (label travels with it).
    images = [image for i, image in enumerate(images0)
              if i % args.thinning == 0]

    def _load_func(index):
        y, name, marchive, mname = images[index]
        # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; this
        # requires an old SciPy.
        im = scipy.misc.imread(marchive.extractfile(mname), mode='RGB')
        x = _resize_image(im, args.width, args.height, args.mode == 'padding')
        # Shift the id down by one; presumably synset ids are 1-based
        # (TODO confirm).
        return x, np.array([y - 1]).astype(np.int32)

    from nnabla.utils.data_source import DataSourceWithFileCache
    from nnabla.utils.data_source_implements import SimpleDataSource
    from nnabla.logger import logger

    logger.info('Num of data : {}'.format(len(images)))
    shuffle = args.shuffle != 'False'
    source = SimpleDataSource(_load_func, len(images), shuffle, rng=None)
    # Pass the parsed bool, not the raw option string: any non-empty string
    # (including 'False') is truthy.
    DataSourceWithFileCache(source, cache_dir=output, shuffle=shuffle)
Ejemplo n.º 7
0
def data_iterator_simple(load_func,
                         num_examples,
                         batch_size,
                         shuffle=False,
                         rng=None,
                         with_memory_cache=True,
                         with_file_cache=True,
                         cache_dir=None,
                         epoch_begin_callbacks=None,
                         epoch_end_callbacks=None):
    """A generator that ``yield`` s minibatch data as a tuple, as defined in ``load_func`` .
    It can unlimitedly yield minibatches at your request, queried from the provided data.

    Args:
        load_func (function): Takes a single argument `i`, an index of an
            example in your dataset to be loaded, and returns a tuple of data.
            Every call by any index `i` must return a tuple of arrays with
            the same shape.
        num_examples (int): Number of examples in your dataset. Random sequence
            of indexes is generated according to this number.
        batch_size (int): Size of data unit.
        shuffle (bool):
             Indicates whether the dataset is shuffled or not.
             Default value is False.
        rng (None or :obj:`numpy.random.RandomState`): Numpy random number
            generator.
        with_memory_cache (bool):
            If ``True``, use :py:class:`.data_source.DataSourceWithMemoryCache`
            to wrap ``data_source``. It is a good idea to set this as true unless
            data_source provides on-memory data.
            Default value is True.
        with_file_cache (bool):
            If ``True``, use :py:class:`.data_source.DataSourceWithFileCache`
            to wrap ``data_source``.
            If ``data_source`` is slow, enabling this option is a good idea.
            Default value is True.
        cache_dir (str):
            Location of file_cache.
            If this value is None, :py:class:`.data_source.DataSourceWithFileCache`
            creates file caches implicitly on temporary directory and erases them all
            when data_iterator is finished.
            Otherwise, :py:class:`.data_source.DataSourceWithFileCache` keeps created cache.
            Default is None.
        epoch_begin_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called
            at the beginning of an epoch. Default is None (no callbacks).
        epoch_end_callbacks (list of functions): An item is a function
            which takes an epoch index as an argument. These are called
            at the end of an epoch. Default is None (no callbacks).


    Returns:
        :py:class:`DataIterator <nnabla.utils.data_iterator.DataIterator>`:
            Instance of DataIterator.


    Here is an example of `load_func` which returns an image and a label of a
    classification dataset.

    .. code-block:: python

        import numpy as np
        from nnabla.utils.image_utils import imread
        image_paths = load_image_paths()
        labels = load_labels()
        def my_load_func(i):
            '''
            Returns:
                image: c x h x w array
                label: 0-shape array
            '''
            img = imread(image_paths[i]).astype('float32')
            return np.rollaxis(img, 2), np.array(labels[i])


    """
    # Use None as the default and create fresh lists per call: a literal []
    # default is a single list shared by every call, so any mutation of it
    # downstream would leak into later iterators.
    if epoch_begin_callbacks is None:
        epoch_begin_callbacks = []
    if epoch_end_callbacks is None:
        epoch_end_callbacks = []
    return data_iterator(SimpleDataSource(load_func,
                                          num_examples,
                                          shuffle=shuffle,
                                          rng=rng),
                         batch_size=batch_size,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache,
                         cache_dir=cache_dir,
                         epoch_begin_callbacks=epoch_begin_callbacks,
                         epoch_end_callbacks=epoch_end_callbacks)