def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20,
                      input_file_fmt, cache_file_fmt, shuffle, normalize,
                      num_of_threads):
    if input_file_fmt == 'csv':
        csvfilename = test_data_csv_csv_20
    else:
        csvfilename = test_data_csv_png_20

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as tmpdir:
        cc = CreateCache(csvfilename, shuffle=shuffle,
                         num_of_threads=num_of_threads)
        cc.create(tmpdir, normalize=normalize)

        # Get cache data source and CSV file data source.
        with closing(CacheDataSource(tmpdir)) as cache_source:
            csv_source = CsvDataSource(csvfilename, normalize=normalize)

            check_relative_csv_file_result(cache_file_fmt, csvfilename, tmpdir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                # Replay the shuffled order recorded by CreateCache so that
                # the CSV source yields rows in the same order as the cache.
                with open(os.path.join(tmpdir, 'order.csv'), 'r') as f:
                    csv_source._order = [int(row[1]) for row in csv.reader(f)]

            # Every sample read from the cache must match the corresponding
            # sample read from the original CSV file.
            for _ in range(cache_source.size):
                cache_data = associate_variables_and_data(cache_source)
                csv_data = associate_variables_and_data(csv_source)

                for v in cache_source.variables:
                    assert_allclose(cache_data[v], csv_data[v])
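
# A minimal usage sketch (not part of the test suite) of the round trip the
# test above exercises: build a cache from a CSV dataset with CreateCache,
# then read it back through CacheDataSource. The CSV path and cache directory
# are hypothetical placeholders; '.npy' is assumed to be one of the supported
# 'cache_file_format' values.
def _example_create_cache_roundtrip():
    from contextlib import closing
    from nnabla.config import nnabla_config
    from nnabla.utils.create_cache import CreateCache
    from nnabla.utils.data_source_implements import CacheDataSource

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', '.npy')

    cc = CreateCache('data/train.csv', shuffle=False, num_of_threads=4)
    cc.create('/tmp/example_cache', normalize=True)

    with closing(CacheDataSource('/tmp/example_cache')) as source:
        print(source.size, source.variables)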
def conv_dataset_command(args):
    if isinstance(args.num_of_threads, int) and args.num_of_threads <= 0:
        print("The number of threads [{}] must be a positive integer.".format(
            args.num_of_threads))
        return False

    if os.path.exists(args.destination):
        if not args.force:
            print('File or directory [{}] already exists. '
                  'Use the `-F` option to overwrite it.'.format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwriting destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            print('Cannot overwrite file [{}]. Please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)

    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':
        if os.path.exists(args.source):
            cc = CreateCache(args.source,
                             shuffle=args.shuffle,
                             num_of_threads=args.num_of_threads)
            print('Number of Data: {}'.format(cc._size))
            print('Shuffle: {}'.format(cc._shuffle))
            print('Normalize: {}'.format(args.normalize))
            cc.create(args.destination, normalize=args.normalize)
        else:
            with CsvDataSource(args.source,
                               shuffle=args.shuffle,
                               normalize=args.normalize) as source:
                _convert(args, source)
    elif ext.lower() == '.cache':
        with CacheDataSource(args.source,
                             shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
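
# A hedged sketch of how conv_dataset_command might be wired into a CLI with
# argparse, inferred from the attributes the function reads (source,
# destination, force, shuffle, normalize, num_of_threads). Apart from `-F`,
# which appears in the error message above, the short flag names and the
# default thread count are assumptions, not the actual nnabla_cli definitions.
def _example_register_conv_dataset(subparsers):
    parser = subparsers.add_parser('conv_dataset',
                                   help='Convert a CSV dataset to a cache.')
    parser.add_argument('-F', '--force', action='store_true',
                        help='overwrite destination if it exists')
    parser.add_argument('-S', '--shuffle', action='store_true',
                        help='shuffle the data order')
    parser.add_argument('-N', '--normalize', action='store_true',
                        help='normalize data range')
    parser.add_argument('-t', '--num_of_threads', type=int, default=10,
                        help='number of threads used to create the cache')
    parser.add_argument('source')
    parser.add_argument('destination')
    parser.set_defaults(func=conv_dataset_command)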
def _create_dataset(
        uri,
        batch_size,
        shuffle,
        no_image_normalization,
        cache_dir,
        overwrite_cache,
        create_cache_explicitly,
        prepare_data_iterator,
        dataset_index):
    class Dataset:
        pass

    dataset = Dataset()
    dataset.uri = uri
    dataset.cache_dir = cache_dir
    dataset.normalize = not no_image_normalization

    comm = current_communicator()

    # Use the same random state for each process until slice is called;
    # a different random state is used for each dataset.
    rng = numpy.random.RandomState(dataset_index)
    use_memory_cache = (comm.size == 1) if comm else True

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None

        # Disable implicit cache creation when MPI is available.
        if cache_dir and (create_cache_explicitly or comm):
            cache_index = os.path.join(cache_dir, "cache_index.csv")
            if not os.path.exists(cache_index) or overwrite_cache:
                if single_or_rankzero():
                    logger.log(99, 'Creating cache data for "' + uri + '"')

                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # python2 does not support the exist_ok arg

                    if os.path.exists(uri):
                        cc = CreateCache(uri, rng=rng, shuffle=shuffle)
                        cc.create(cache_dir, normalize=False)
                    else:
                        with data_iterator_csv_dataset(
                                uri,
                                batch_size,
                                shuffle,
                                rng=rng,
                                normalize=False,
                                cache_dir=cache_dir,
                                with_memory_cache=False):
                            pass

            rng = numpy.random.RandomState(dataset_index)
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir,
                batch_size,
                shuffle,
                rng=rng,
                normalize=dataset.normalize,
                with_memory_cache=use_memory_cache))
        elif not cache_dir or overwrite_cache \
                or not os.path.exists(cache_dir) \
                or len(os.listdir(cache_dir)) == 0:
            if comm:
                logger.critical(
                    'Implicit cache creation is not supported with MPI.')
                import sys
                sys.exit(-1)
            else:
                if cache_dir:
                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # python2 does not support the exist_ok arg
                dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                    uri,
                    batch_size,
                    shuffle,
                    rng=rng,
                    normalize=dataset.normalize,
                    cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir,
                batch_size,
                shuffle,
                rng=rng,
                normalize=dataset.normalize,
                with_memory_cache=use_memory_cache))
    else:
        dataset.data_iterator = None
    return dataset
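
# A hedged usage sketch (hypothetical URI and cache directory): build a
# dataset with explicit cache creation and draw one batch from its data
# iterator. nnabla data iterators support the context-manager protocol, and
# next() returns a tuple of numpy arrays, one per dataset variable.
def _example_create_dataset():
    dataset = _create_dataset(
        uri='data/train.csv',
        batch_size=64,
        shuffle=True,
        no_image_normalization=False,
        cache_dir='/tmp/example_train_cache',
        overwrite_cache=False,
        create_cache_explicitly=True,
        prepare_data_iterator=True,
        dataset_index=0)
    with dataset.data_iterator() as di:
        batch = di.next()  # one mini-batch of batch_size samples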