def test_create_cache(test_data_csv_csv_20, test_data_csv_png_20, input_file_fmt,
                      cache_file_fmt, shuffle, normalize, num_of_threads):
    """Build a cache from a CSV source and verify it matches the original data.

    Creates the cache with ``CreateCache``, then reads it back through
    ``CacheDataSource`` and compares every variable of every sample against
    the ``CsvDataSource`` it was built from.
    """
    if input_file_fmt == 'csv':
        source_csv = test_data_csv_csv_20
    else:
        source_csv = test_data_csv_png_20

    nnabla_config.set('DATA_ITERATOR', 'cache_file_format', cache_file_fmt)

    with create_temp_with_dir() as cache_dir:
        creator = CreateCache(source_csv, shuffle=shuffle,
                              num_of_threads=num_of_threads)
        creator.create(cache_dir, normalize=normalize)

        # Open both the generated cache and the original CSV for comparison.
        with closing(CacheDataSource(cache_dir)) as cache_source:
            csv_source = CsvDataSource(source_csv, normalize=normalize)

            check_relative_csv_file_result(cache_file_fmt, source_csv, cache_dir)

            assert cache_source.size == csv_source.size
            assert set(cache_source.variables) == set(csv_source.variables)

            if shuffle:
                # Replay the shuffle order recorded by CreateCache so the CSV
                # source yields samples in the same sequence as the cache.
                with open(os.path.join(cache_dir, 'order.csv'), 'r') as f:
                    csv_source._order = [int(row[1]) for row in csv.reader(f)]

            for _ in range(cache_source.size):
                cached = associate_variables_and_data(cache_source)
                expected = associate_variables_and_data(csv_source)
                for variable in cache_source.variables:
                    assert_allclose(cached[variable], expected[variable])
def test_data_iterator_concat_datasets(test_data_csv_png_10, test_data_csv_png_20,
                                       batch_size, shuffle, use_thread, normalize,
                                       with_memory_cache, with_file_cache,
                                       with_context, stop_exhausted):
    """Iterate over two concatenated CSV data sources and validate the batches.

    Exercises both the context-manager and the explicit-close code paths of
    ``data_iterator_concat_datasets``.
    """
    # Configure small cache/buffer settings for this test run.
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    first = CsvDataSource(test_data_csv_png_10, shuffle=shuffle,
                          normalize=normalize)
    second = CsvDataSource(test_data_csv_png_20, shuffle=shuffle,
                           normalize=normalize)

    iterator_kwargs = dict(batch_size=batch_size,
                           shuffle=shuffle,
                           with_memory_cache=with_memory_cache,
                           with_file_cache=with_file_cache,
                           use_thread=use_thread,
                           stop_exhausted=stop_exhausted)

    if with_context:
        with data_iterator_concat_datasets([first, second],
                                           **iterator_kwargs) as di:
            check_data_iterator_concat_result(
                di, batch_size, normalize, first.size, second.size,
                stop_exhausted)
    else:
        di = data_iterator_concat_datasets([first, second], **iterator_kwargs)
        check_data_iterator_concat_result(
            di, batch_size, normalize, first.size, second.size, stop_exhausted)
        di.close()
def conv_dataset_command(args):
    """Convert a CSV or cache dataset into the destination cache directory.

    Prepares ``args.destination`` (creating it, or wiping and recreating it
    when ``args.force`` is set), then streams the source through ``_convert``.

    Returns:
        bool: True on success (including the unsupported-extension case,
        which only prints a message); False when the destination cannot be
        prepared.
    """
    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] is exists use `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            # Destination exists but is a plain file; refuse to clobber it.
            print('Cannot overwrite file [{}] please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)

    # Removed dead local `datasource = None` — it was never read.
    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':
        with CsvDataSource(args.source, shuffle=args.shuffle,
                           normalize=args.normalize) as source:
            _convert(args, source)
    elif ext.lower() == '.cache':
        with CacheDataSource(args.source, shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
def test_csv_data_source(test_data_csv_csv_10, test_data_csv_csv_20, size, shuffle):
    """Read every sample from a CSV data source and check the iteration order."""
    if size == 10:
        path = test_data_csv_csv_10
    elif size == 20:
        path = test_data_csv_csv_20

    source = CsvDataSource(path, shuffle)
    source.reset()

    seen = []
    for _ in range(source.size):
        data, label = source.next()
        # Each row encodes its own index in both the data and the label.
        assert data[0][0] == label[0]
        seen.append(int(round(data[0][0])))

    sequential = list(range(size))
    if shuffle:
        # Shuffled order must differ from sequential yet contain the same items.
        assert seen != sequential
        assert sorted(seen) == sequential
    else:
        assert seen == sequential
def test_csv_data_source(test_data_csv_csv_10, test_data_csv_csv_20, size, shuffle):
    """Verify a CSV data source yields all samples, shuffled or sequential."""
    if size == 10:
        csv_path = test_data_csv_csv_10
    elif size == 20:
        csv_path = test_data_csv_csv_20

    data_source = CsvDataSource(csv_path, shuffle)
    data_source.reset()

    observed_order = []
    for _ in range(data_source.size):
        data, label = data_source.next()
        # Data value and label agree for every generated test row.
        assert data[0][0] == label[0]
        observed_order.append(int(round(data[0][0])))

    expected = list(range(size))
    if shuffle:
        assert observed_order != expected   # order was actually permuted
        assert sorted(observed_order) == expected  # but nothing lost/duplicated
    else:
        assert observed_order == expected
def conv_dataset_command(args):
    """Convert a CSV or cache dataset into the destination cache directory.

    Validates ``args.num_of_threads``, prepares ``args.destination``
    (creating it, or wiping and recreating it when ``args.force`` is set),
    then converts the source. CSV sources that exist on disk take the fast
    ``CreateCache`` path; otherwise the generic ``_convert`` path is used.

    Returns:
        bool: True on success (including the unsupported-extension case,
        which only prints a message); False on invalid arguments or an
        unusable destination.
    """
    # isinstance() is the idiomatic type check (type(...) == int rejects
    # subclasses and is flagged by linters).
    if isinstance(args.num_of_threads, int) and args.num_of_threads <= 0:
        print("The numbers of threads [{}] must be positive integer.".format(
            args.num_of_threads))
        return False

    if os.path.exists(args.destination):
        if not args.force:
            print(
                'File or directory [{}] is exists use `-F` option to overwrite it.'
                .format(args.destination))
            return False
        elif os.path.isdir(args.destination):
            print('Overwrite destination [{}].'.format(args.destination))
            shutil.rmtree(args.destination, ignore_errors=True)
            os.mkdir(args.destination)
        else:
            # Destination exists but is a plain file; refuse to clobber it.
            print('Cannot overwrite file [{}] please delete it.'.format(
                args.destination))
            return False
    else:
        os.mkdir(args.destination)

    _, ext = os.path.splitext(args.source)
    if ext.lower() == '.csv':
        if os.path.exists(args.source):
            cc = CreateCache(args.source, shuffle=args.shuffle,
                             num_of_threads=args.num_of_threads)
            # NOTE(review): reads CreateCache private attributes for
            # reporting — confirm a public accessor isn't available.
            print('Number of Data: {}'.format(cc._size))
            print('Shuffle: {}'.format(cc._shuffle))
            print('Normalize: {}'.format(args.normalize))
            cc.create(args.destination, normalize=args.normalize)
        else:
            with CsvDataSource(args.source, shuffle=args.shuffle,
                               normalize=args.normalize) as source:
                _convert(args, source)
    elif ext.lower() == '.cache':
        with CacheDataSource(args.source, shuffle=args.shuffle,
                             normalize=args.normalize) as source:
            _convert(args, source)
    else:
        print('Command `conv_dataset` only supports CSV or CACHE as source.')
    return True
def test_concat_data_source(test_data_csv_csv_10, test_data_csv_csv_20, shuffle):
    """Concatenate two CSV sources and verify the combined iteration order."""
    csv_paths = (test_data_csv_csv_10, test_data_csv_csv_20)
    sources = [CsvDataSource(path, shuffle) for path in csv_paths]

    concat = ConcatDataSource(sources, shuffle)
    concat.reset()

    seen = []
    for _ in range(concat.size):
        data, label = concat.next()
        # Each row encodes its own index in both the data and the label.
        assert data[0][0] == label[0]
        seen.append(int(round(data[0][0])))

    sequential = list(range(10)) + list(range(20))
    if shuffle:
        # Shuffled order differs from sequential but has the same multiset.
        assert seen != sequential
        assert sorted(sequential) == sorted(seen)
    else:
        assert seen == sequential