Пример #1
0
def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20,
                      input_file_fmt, cache_file_fmt, shuffle, normalize,
                      num_of_threads):
    """Check that a cache built by ``CreateCache`` matches its source CSV.

    The CSV fixture is selected by ``input_file_fmt``. It is converted into a
    cache inside a temporary directory. Then every sample read back through
    ``CacheDataSource`` is compared with the matching sample read through
    ``CsvDataSource``.
    """
    if input_file_fmt == 'csv':
        csvfilename = test_data_csv_csv_20
    else:
        csvfilename = test_data_csv_png_20

    # NOTE(review): this changes global configuration and is not restored
    # afterwards, so later tests observe the last value set here.
    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as tmpdir:
        cc = CreateCache(csvfilename,
                         shuffle=shuffle,
                         num_of_threads=num_of_threads)
        cc.create(tmpdir, normalize=normalize)

        # Open the cache data source and the reference CSV data source.
        # Both are managed by the `with`, so neither is left open when the
        # test finishes (previously the CSV source was never closed).
        with closing(CacheDataSource(tmpdir)) as cache_source, \
                CsvDataSource(csvfilename, normalize=normalize) as csv_source:

            check_relative_csv_file_result(cache_file_fmt, csvfilename, tmpdir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                # Replay the order stored in order.csv on the CSV source, so
                # that both sources yield samples in the same order.
                # Column 1 of each row is used as the CSV sample index.
                with open(os.path.join(tmpdir, 'order.csv'), 'r') as f:
                    csv_source._order = [int(row[1]) for row in csv.reader(f)]

            for _ in range(cache_source.size):
                cache_data = associate_variables_and_data(cache_source)
                csv_data = associate_variables_and_data(csv_source)

                for v in cache_source.variables:
                    assert_allclose(cache_data[v], csv_data[v])
Пример #2
0
def conv_dataset_command(args):
    """Convert the dataset at ``args.source`` into ``args.destination``.

    Supported sources:

    * ``.csv``: a path that exists locally is converted with ``CreateCache``.
      A ``.csv`` path that does not exist locally is read through
      ``CsvDataSource`` and passed to ``_convert``.
    * ``.cache``: read through ``CacheDataSource`` and passed to ``_convert``.

    Args:
        args: parsed command-line namespace. The attributes read are
            ``num_of_threads``, ``destination``, ``force``, ``source``,
            ``shuffle`` and ``normalize``.

    Returns:
        bool: ``False`` (after printing a message) in these cases:
        ``num_of_threads`` is a non-positive int; the source extension is
        unsupported; or the destination already exists and may not be
        overwritten. ``True`` otherwise.
    """
    if isinstance(args.num_of_threads, int) and args.num_of_threads <= 0:
        print("The numbers of threads [{}] must be positive integer.".format(
            args.num_of_threads))
        return False

    # Validate the source format BEFORE touching the destination. This way an
    # unsupported source never causes an existing destination to be wiped.
    ext = os.path.splitext(args.source)[1].lower()
    if ext not in ('.csv', '.cache'):
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
        return False

    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] is exists use `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            print('Cannot overwrite file [{}] please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)

    if ext == '.cache':
        with CacheDataSource(args.source,
                             shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    elif os.path.exists(args.source):
        cc = CreateCache(args.source,
                         shuffle=args.shuffle,
                         num_of_threads=args.num_of_threads)
        print('Number of Data: {}'.format(cc._size))
        print('Shuffle:        {}'.format(cc._shuffle))
        print('Normalize:      {}'.format(args.normalize))
        cc.create(args.destination, normalize=args.normalize)
    else:
        # NOTE(review): a .csv source that is not a local file is presumably
        # a remote URI that CsvDataSource can open — confirm.
        with CsvDataSource(args.source,
                           shuffle=args.shuffle,
                           normalize=args.normalize) as source:
            _convert(args, source)
    return True
Пример #3
0
def _makedirs_exist_ok(path):
    """Create ``path`` (and any parents) unless it is already a directory.

    Same effect as ``os.makedirs(path, exist_ok=True)``, but written without
    ``exist_ok`` so that it also runs on Python 2.
    """
    try:
        os.makedirs(path)
    except OSError:
        # "Already exists as a directory" is the only benign failure. Any
        # other error (permission denied, a regular file at `path`, ...) is
        # re-raised here instead of surfacing later as a confusing failure.
        if not os.path.isdir(path):
            raise


def _create_dataset(uri, batch_size, shuffle, no_image_normalization,
                    cache_dir, overwrite_cache, create_cache_explicitly,
                    prepare_data_iterator, dataset_index):
    """Build a dataset descriptor and, optionally, a lazy data-iterator factory.

    Args:
        uri: dataset location. It is checked with ``os.path.exists``; when no
            local file exists, it is handed to ``data_iterator_csv_dataset``.
        batch_size: batch size passed to the data iterator.
        shuffle: whether the data iterator shuffles.
        no_image_normalization: when true, ``dataset.normalize`` is False.
        cache_dir: cache directory. ``''`` is treated like ``None`` here.
        overwrite_cache: rebuild the cache even if one already exists.
        create_cache_explicitly: create the cache up front instead of
            letting the CSV iterator create it implicitly.
        prepare_data_iterator: when false, ``data_iterator`` is ``None``.
        dataset_index: seed of this dataset's ``numpy.random.RandomState``.

    Returns:
        An object with attributes ``uri``, ``cache_dir`` (the value as passed
        in), ``normalize``, and ``data_iterator``. ``data_iterator`` is a
        zero-argument callable that creates the data iterator, or ``None``.
    """
    class Dataset:
        pass

    dataset = Dataset()
    dataset.uri = uri
    dataset.cache_dir = cache_dir
    dataset.normalize = not no_image_normalization

    comm = current_communicator()

    # use same random state for each process until slice is called
    # different random state is used for each dataset
    rng = numpy.random.RandomState(dataset_index)
    # The memory cache is disabled only when a communicator of size > 1 exists.
    use_memory_cache = comm.size == 1 if comm else True

    if not prepare_data_iterator:
        dataset.data_iterator = None
        return dataset

    if cache_dir == '':
        cache_dir = None

    def cache_iterator():
        # Closure: `cache_dir` and `rng` are read when this is called, i.e.
        # after this function has assigned their final values.
        return data_iterator_cache(
            cache_dir,
            batch_size,
            shuffle,
            rng=rng,
            normalize=dataset.normalize,
            with_memory_cache=use_memory_cache)

    # Disable implicit cache creation when MPI is available.
    if cache_dir and (create_cache_explicitly or comm):
        cache_index = os.path.join(cache_dir, "cache_index.csv")
        if not os.path.exists(cache_index) or overwrite_cache:
            # NOTE(review): other ranks do not wait here for rank zero to
            # finish writing the cache; presumably that synchronisation
            # happens elsewhere — confirm.
            if single_or_rankzero():
                logger.log(99, 'Creating cache data for "' + uri + '"')
                _makedirs_exist_ok(cache_dir)

                if os.path.exists(uri):
                    cc = CreateCache(uri, rng=rng, shuffle=shuffle)
                    cc.create(cache_dir, normalize=False)
                else:
                    # Only the side effect of opening the iterator with
                    # `cache_dir` set is wanted (presumably it writes the
                    # cache — confirm); the iterator itself is not used.
                    with data_iterator_csv_dataset(
                            uri,
                            batch_size,
                            shuffle,
                            rng=rng,
                            normalize=False,
                            cache_dir=cache_dir,
                            with_memory_cache=False):
                        pass

        # Re-seed so every process starts from the same RNG state, whether
        # or not it consumed random numbers while creating the cache above.
        rng = numpy.random.RandomState(dataset_index)
        dataset.data_iterator = cache_iterator
    elif (not cache_dir or overwrite_cache or not os.path.exists(cache_dir)
          or not os.listdir(cache_dir)):
        if comm:
            logger.critical(
                'Implicit cache creation does not support with MPI')
            import sys
            sys.exit(-1)
        if cache_dir:
            _makedirs_exist_ok(cache_dir)
        dataset.data_iterator = (lambda: data_iterator_csv_dataset(
            uri,
            batch_size,
            shuffle,
            rng=rng,
            normalize=dataset.normalize,
            cache_dir=cache_dir))
    else:
        # A non-empty cache directory already exists: read from it directly.
        dataset.data_iterator = cache_iterator
    return dataset