def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization
    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None
        if cache_dir and create_cache_explicitly:
            if not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0 or overwrite_cache:
                if not os.path.exists(cache_dir):
                    os.mkdir(cache_dir)
                logger.log(99, 'Creating cache data for "' + uri + '"')
                with data_iterator_csv_dataset(uri, batch_size, shuffle, normalize=False, cache_dir=cache_dir) as di:
                    index = 0
                    while index < di.size:
                        progress('', (1.0 * di.position) / di.size)
                        di.next()
                        index += batch_size
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0:
            if cache_dir and not os.path.exists(cache_dir):
                os.mkdir(cache_dir)
            dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                uri, batch_size, shuffle, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
    else:
        dataset.data_iterator = None
    return dataset

def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization
    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None
        if cache_dir and create_cache_explicitly:
            if not os.path.exists(cache_dir) or overwrite_cache:
                if not os.path.exists(cache_dir):
                    os.mkdir(cache_dir)
                logger.info('Creating cache data for "' + uri + '"')
                with data_iterator_csv_dataset(uri, batch_size, shuffle, normalize=False, cache_dir=cache_dir) as di:
                    index = 0
                    while index < di.size:
                        progress('', (1.0 * di.position) / di.size)
                        di.next()
                        index += batch_size
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir):
            if cache_dir and not os.path.exists(cache_dir):
                os.mkdir(cache_dir)
            dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                uri, batch_size, shuffle, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, normalize=dataset.normalize))
    else:
        dataset.data_iterator = None
    return dataset

def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, normalize,
                                   with_memory_cache, with_file_cache, with_context):
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache) as di:
            check_data_iterator_result(di, batch_size, shuffle, normalize)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache)
        check_data_iterator_result(di, batch_size, shuffle, normalize)
        di.close()

def _create_dataset(uri, batch_size, shuffle, no_image_normalization, cache_dir, overwrite_cache, create_cache_explicitly, prepare_data_iterator):
    class Dataset:
        pass
    dataset = Dataset()
    dataset.uri = uri
    dataset.normalize = not no_image_normalization

    comm = current_communicator()

    # use same random state for each process until slice is called
    rng = numpy.random.RandomState(0)
    use_memory_cache = comm.size == 1 if comm else True

    if prepare_data_iterator:
        if cache_dir == '':
            cache_dir = None

        # Disable implicit cache creation when MPI is available.
        if cache_dir and (create_cache_explicitly or comm):
            cache_index = os.path.join(cache_dir, "cache_index.csv")
            if not os.path.exists(cache_index) or overwrite_cache:
                if single_or_rankzero():
                    logger.log(99, 'Creating cache data for "' + uri + '"')

                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # python2 does not support exists_ok arg

                    with data_iterator_csv_dataset(uri, batch_size, shuffle, rng=rng, normalize=False, cache_dir=cache_dir, with_memory_cache=False) as di:
                        pass

            rng = numpy.random.RandomState(0)
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, rng=rng, normalize=dataset.normalize, with_memory_cache=use_memory_cache))
        elif not cache_dir or overwrite_cache or not os.path.exists(cache_dir) or len(os.listdir(cache_dir)) == 0:
            if comm:
                logger.critical(
                    'Implicit cache creation does not support with MPI')
                import sys
                sys.exit(-1)
            else:
                if cache_dir:
                    try:
                        os.makedirs(cache_dir)
                    except OSError:
                        pass  # python2 does not support exists_ok arg
                dataset.data_iterator = (lambda: data_iterator_csv_dataset(
                    uri, batch_size, shuffle, rng=rng, normalize=dataset.normalize, cache_dir=cache_dir))
        else:
            dataset.data_iterator = (lambda: data_iterator_cache(
                cache_dir, batch_size, shuffle, rng=rng, normalize=dataset.normalize, with_memory_cache=use_memory_cache))
    else:
        dataset.data_iterator = None
    return dataset

def _load_dataset(self, dataset_path, batch_size=100, shuffle=False):
    if os.path.isfile(dataset_path):
        logger.info("Load a dataset from {}.".format(dataset_path))
        return data_iterator_csv_dataset(dataset_path, batch_size, shuffle=shuffle)
    return None

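For orientation, here is a minimal sketch of how the iterator returned by a loader like _load_dataset above is typically consumed; the CSV path and batch size are hypothetical, and only attributes that appear in the surrounding examples (variables, size, next()) are used:

from nnabla.utils.data_iterator import data_iterator_csv_dataset

# Hypothetical CSV path and batch size, for illustration only.
with data_iterator_csv_dataset('train.csv', batch_size=64, shuffle=True) as di:
    print(di.variables)       # variable names taken from the CSV header
    index = 0
    while index < di.size:    # di.size is the number of samples in the CSV
        batch = di.next()     # one array per variable, in di.variables order
        index += len(batch[0])
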
def loadData(batch_size):
    cache_dir = "./cache"
    if not os.path.isdir(cache_dir):
        os.mkdir(cache_dir)
        dataset = data_iterator_csv_dataset(
            "../AIStudy.BoardGame/experience.csv", batch_size,
            shuffle=True, normalize=True, cache_dir=cache_dir)
    else:
        dataset = data_iterator_cache(cache_dir, batch_size, shuffle=True, normalize=True)

    variables = dataset.variables
    print(variables)
    """
    s0Index = variables.index('s0')
    print(("index(s0)={}").format(s0Index))
    for n in range(1000):
        data = dataset.next()
        if n == 0:
            print(("shape(s0)={}").format(data[s0Index].shape))
        print(("epoch={},position={},size={}, data_size={}").format(
            dataset.epoch, dataset.position, dataset.size, len(data[0])))
    """
    return dataset

def common_forward(info, forward_func):
    batch_size = 1

    class ForwardConfig:
        pass

    class Args:
        pass

    args = Args()
    config = ForwardConfig
    if hasattr(info, 'global_config'):
        config.global_config = info.global_config
    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            assert False, "{} is not found.".format(e.network.name)

    normalize = True
    for d in info.datasets.values():
        args.dataset = d.uri
        normalize = d.normalize
        break
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    data_iterator = (lambda: data_iterator_csv_dataset(
        uri=args.dataset,
        batch_size=config.networks[0].batch_size,
        shuffle=False,
        normalize=normalize,
        with_memory_cache=False,
        with_file_cache=False))

    result = []
    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            avg = forward_func(args, index, config, data, di.variables)
            index += len(avg[0])
            result.append(avg[0])

    return np.array(result)

def forward_command(args):
    callback.update_status(args)

    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)
    batch_size = args.batch_size
    if batch_size < 1:
        batch_size = None

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False, batch_size=batch_size)
    config.global_config = info.global_config
    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                config.executor.network.name))
            return False

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset or d.cache_dir == args.dataset:
            normalize = d.normalize
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    orders = {}
    # With CSV
    if os.path.splitext(args.dataset)[1] == '.csv':
        data_iterator = (lambda: data_iterator_csv_dataset(
            uri=args.dataset,
            batch_size=config.networks[0].batch_size,
            shuffle=False,
            normalize=normalize,
            with_memory_cache=False,
            with_file_cache=False))

        # load dataset as csv
        filereader = FileReader(args.dataset)
        with filereader.open(textmode=True, encoding='utf-8-sig') as f:
            rows = [row for row in csv.reader(f)]
        row0 = rows.pop(0)
        if args.replace_path:
            root_path = os.path.dirname(args.dataset)
            root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
        else:
            root_path = '.'
        rows = [row for row in rows if len(row)]
        rows = list(
            map(
                lambda row: list(
                    map(
                        lambda i, x: x if row0[i][0] == '#' or is_float(
                            x) else compute_full_path(root_path, x),
                        range(len(row)),
                        row)),
                rows))
        for i in range(len(rows)):
            orders[i] = i
    # With Cache
    elif os.path.splitext(args.dataset)[1] == '.cache':
        data_iterator = (lambda: data_iterator_cache(
            uri=args.dataset,
            batch_size=config.networks[0].batch_size,
            shuffle=False,
            normalize=normalize))

        # Get original CSV
        original_csv = os.path.join(args.dataset, 'original.csv')
        try:
            # load dataset as csv
            filereader = FileReader(original_csv)
            with filereader.open(textmode=True, encoding='utf-8-sig') as f:
                rows = [row for row in csv.reader(f)]
            row0 = rows.pop(0)
            root_path = '.'
            rows = list(
                map(
                    lambda row: list(
                        map(
                            lambda x: x if is_float(x) else compute_full_path(
                                root_path, x),
                            row)),
                    rows))
        except:
            print('Cannot open', original_csv)
            pass

        # Get original Data order.
        order_csv = os.path.join(args.dataset, 'order.csv')
        try:
            filereader = FileReader(order_csv)
            with filereader.open(textmode=True) as f:
                for original, shuffled in [[int(x) for x in row] for row in csv.reader(f)]:
                    orders[original] = shuffled
        except:
            print('Cannot open', order_csv)
            for i in range(len(rows)):
                orders[i] = i
    else:
        print('Unsupported extension "{}" in "{}".'.format(
            os.path.splitext(args.dataset)[1], args.dataset))

    callback.update_status(('data.max', len(rows)))
    callback.update_status(('data.current', 0))
    callback.update_status('processing', True)

    result_csv_filename = os.path.join(args.outdir, args.outfile)
    with open(result_csv_filename, 'w', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        with data_iterator() as di:
            index = 0
            while index < di.size:
                data = di.next()
                result, outputs = _forward(args, index, config, data, di.variables)
                if index == 0:
                    for name, dim in zip(result.names, result.dims):
                        if dim == 1:
                            if e.repeat_evaluation_type == "std":
                                name = "Uncertainty(Std)"
                            row0.append(name)
                        else:
                            for d in range(dim):
                                row0.append(name + '__' + str(d))
                    writer.writerow(row0)
                for i, output in enumerate(outputs):
                    if index + i < len(rows):
                        import copy
                        row = copy.deepcopy(rows[orders[index + i]])
                        row.extend(output)
                        writer.writerow(row)
                index += len(outputs)
                callback.update_status(('data.current', min([index, len(rows)])))
                callback.update_forward_time()
                callback.update_status()

                logger.log(
                    99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    callback.process_evaluation_result(args.outdir, result_csv_filename)

    logger.log(99, 'Forward Completed.')
    progress(None)

    callback.update_status(('output_result.csv_header', ','.join(row0)))
    callback.update_status(('output_result.column_num', len(row0)))
    callback.update_status(('output_result.data_num', len(rows)))
    callback.update_status('finished')

    return True

            'vision', 'imagenet'))
        from tiny_imagenet_data import data_iterator_tiny_imagenet
        with data_iterator_tiny_imagenet(args.batch_size, 'train') as di:
            test_data_iterator(di, args)
    elif args.uri == 'TINY_IMAGENET_VAL':
        sys.path.append(
            os.path.join(os.path.dirname(os.path.abspath(__file__)), '..',
                         'vision', 'imagenet'))
        from tiny_imagenet_data import data_iterator_tiny_imagenet
        with data_iterator_tiny_imagenet(args.batch_size, 'val') as di:
            test_data_iterator(di, args)
    else:
        if os.path.splitext(args.uri)[1].lower() == '.cache':
            from nnabla.utils.data_iterator import data_iterator_cache
            with data_iterator_cache(uri=args.uri,
                                     batch_size=args.batch_size,
                                     shuffle=args.shuffle,
                                     with_memory_cache=args.memory_cache,
                                     normalize=args.normalize) as di:
                test_data_iterator(di, args)
        else:
            from nnabla.utils.data_iterator import data_iterator_csv_dataset
            with data_iterator_csv_dataset(uri=args.uri,
                                           batch_size=args.batch_size,
                                           shuffle=args.shuffle,
                                           normalize=args.normalize,
                                           with_memory_cache=args.memory_cache,
                                           with_file_cache=args.file_cache,
                                           cache_dir=args.output) as di:
                test_data_iterator(di, args)

def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)
    batch_size = args.batch_size
    if batch_size < 1:
        batch_size = None

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False, batch_size=batch_size)
    config.global_config = info.global_config
    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                config.executor.network.name))
            return False

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    for e in config.executors:
        normalize = normalize and not e.no_image_normalization

    data_iterator = (lambda: data_iterator_csv_dataset(
        uri=args.dataset,
        batch_size=config.networks[0].batch_size,
        shuffle=False,
        normalize=normalize,
        with_memory_cache=False,
        with_file_cache=False))

    # load dataset as csv
    filereader = FileReader(args.dataset)
    with filereader.open(textmode=True) as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    rows = list(
        map(
            lambda row: list(
                map(
                    lambda x: x if is_float(x) else compute_full_path(root_path, x),
                    row)),
            rows))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        with data_iterator() as di:
            index = 0
            while index < di.size:
                data = di.next()
                result, outputs = _forward(args, index, config, data, di.variables)
                if index == 0:
                    for name, dim in zip(result.names, result.dims):
                        if dim == 1:
                            row0.append(name)
                        else:
                            for d in range(dim):
                                row0.append(name + '__' + str(d))
                    writer.writerow(row0)
                for i, output in enumerate(outputs):
                    if index + i < len(rows):
                        import copy
                        row = copy.deepcopy(rows[index + i])
                        row.extend(output)
                        writer.writerow(row)
                index += len(outputs)
                logger.log(
                    99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    logger.log(99, 'Forward Completed.')
    progress(None)
    return True

def get_data_iterator_and_num_class(args):
    """
    Get Data_iterator for training and test data set.
    Also, obtain class / category information from data.
    """
    if args.train_csv:
        from nnabla.utils.data_iterator import data_iterator_csv_dataset
        data_iterator = data_iterator_csv_dataset

        if args.test_csv:
            assert os.path.isfile(
                args.test_csv), "csv file for test not found."
            # check the number of the classes / categories
            with open(args.train_csv, "r") as f:
                csv_data_train = f.readlines()[1:]  # line 1: "x:image,y:label"
                classes_train = {
                    line.split(",")[-1].strip() for line in csv_data_train
                }
            with open(args.test_csv, "r") as f:
                # first line: "x:image,y:label"
                csv_data_test = f.readlines()[1:]
                classes_test = {
                    line.split(",")[-1].strip() for line in csv_data_test
                }
            classes_train.update(classes_test)
            num_class = len(classes_train)

            data_iterator_train = data_iterator_csv_dataset(
                args.train_csv, args.batch_size, args.shuffle, normalize=False)
            data_iterator_valid = data_iterator_csv_dataset(
                args.test_csv, args.batch_size, args.shuffle, normalize=False)
        else:
            print("No csv file for test given. So split the training data")
            assert isinstance(args.ratio, float), "ratio must be in (0.0, 1.0)"
            # check the number of the classes / categories
            with open(args.train_csv, "r") as f:
                # first line is "x:image,y:label"
                csv_data_train = f.readlines()[1:]
                all_classes = {
                    line.split(",")[-1].strip() for line in csv_data_train
                }
            num_class = len(all_classes)

            all_data = data_iterator_csv_dataset(
                args.train_csv, args.batch_size, args.shuffle, normalize=False)
            num_samples = all_data.size
            num_train_samples = int(args.ratio * num_samples)
            data_iterator_train = all_data.slice(
                rng=None, slice_start=0, slice_end=num_train_samples)
            data_iterator_valid = all_data.slice(
                rng=None, slice_start=num_train_samples, slice_end=num_samples)
    else:
        # use caltech101 data like tutorial
        from caltech101_data import data_iterator_caltech101
        assert isinstance(args.ratio, float), "ratio must be in (0.0, 1.0)"
        data_iterator = data_iterator_caltech101
        num_class = 101  # pre-defined (excluding background class)

        all_data = data_iterator(args.batch_size, width=args.width, height=args.height)
        num_samples = all_data.size
        num_train_samples = int(args.ratio * num_samples)
        data_iterator_train = all_data.slice(
            rng=None, slice_start=0, slice_end=num_train_samples)
        data_iterator_valid = all_data.slice(
            rng=None, slice_start=num_train_samples, slice_end=num_samples)

    print("training images: {}".format(data_iterator_train.size))
    print("validation images: {}".format(data_iterator_valid.size))
    print("{} categories included.".format(num_class))

    return data_iterator_train, data_iterator_valid, num_class

def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False)
    config.global_config = info.global_config
    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                config.executor.network.name))
            return

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    data_iterator = (lambda: data_iterator_csv_dataset(
        args.dataset, config.networks[0].batch_size, False, normalize=normalize))

    # load dataset as csv
    with open(args.dataset, 'rt') as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    rows = list(map(lambda row: list(map(lambda x: x if is_float(
        x) else compute_full_path(root_path, x), row)), rows))

    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            result, outputs = forward(args, index, config, data, di.variables)
            if index == 0:
                for name, dim in zip(result.names, result.dims):
                    if dim == 1:
                        row0.append(name)
                    else:
                        for d in range(dim):
                            row0.append(name + '__' + str(d))
            for i, output in enumerate(outputs):
                if index + i < len(rows):
                    rows[index + i].extend(output)
            index += len(outputs)
            logger.log(
                99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(row0)
        writer.writerows(rows)

    logger.log(99, 'Forward Completed.')
    progress(None)

    return args


if __name__ == '__main__':
    config = get_args()
    logger.info("Running in %s" % config.context)
    seed(0)
    ctx = get_extension_context(config.context, device_id=config.device_id)
    nn.set_default_context(ctx)
    nn.clear_parameters()
    net = MLP(config)
    if config.process == 'train':
        net.train()
    elif config.process == 'evaluate':
        net.evaluate()
    elif config.process == 'infer':
        net.init_for_infer()
        if os.path.isfile(config.evaluation_dataset_path):
            logger.info("Load a dataset from {}.".format(
                config.evaluation_dataset_path))
            edata = data_iterator_csv_dataset(
                config.evaluation_dataset_path, 1, shuffle=False)
            for i in range(edata.size):
                data = edata.next()
                x_d = data[0]
                result = net.infer(x_d)
                print("inference result = {}".format(result))

def forward_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)
    if args.param:
        files.append(args.param)

    class ForwardConfig:
        pass
    config = ForwardConfig
    info = load.load(files, prepare_data_iterator=False)
    config.global_config = info.global_config
    config.executors = info.executors.values()

    config.networks = []
    for e in config.executors:
        if e.network.name in info.networks.keys():
            config.networks.append(info.networks[e.network.name])
        else:
            logger.critical('Network {} is not found.'.format(
                config.executor.network.name))
            return

    normalize = True
    for d in info.datasets.values():
        if d.uri == args.dataset:
            normalize = d.normalize
    data_iterator = (lambda: data_iterator_csv_dataset(
        args.dataset, config.networks[0].batch_size, False,
        padding=True, normalize=normalize))

    # load dataset as csv
    with open(args.dataset, 'rt') as f:
        rows = [row for row in csv.reader(f)]
    row0 = rows.pop(0)
    root_path = os.path.dirname(args.dataset)
    root_path = os.path.abspath(root_path.replace('/|\\', os.path.sep))
    # wrap map() in list() so that len() and item assignment below work on Python 3
    rows = list(map(lambda row: list(map(lambda x: x if is_float(
        x) else compute_full_path(root_path, x), row)), rows))

    with data_iterator() as di:
        index = 0
        while index < di.size:
            data = di.next()
            result, outputs = forward(args, index, config, data, di.variables)
            if index == 0:
                for name, dim in zip(result.names, result.dims):
                    if dim == 1:
                        row0.append(name)
                    else:
                        for d in range(dim):
                            row0.append(name + '__' + str(d))
            for i, output in enumerate(outputs):
                if index + i < len(rows):
                    rows[index + i].extend(output)
            index += len(outputs)
            logger.log(
                99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))

    with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(row0)
        writer.writerows(rows)

    logger.log(99, 'Forward Completed.')
    progress(None)

def test_data_iterator_csv_dataset(test_data_csv_png_10, test_data_csv_png_20,
                                   size, batch_size, shuffle, use_thread, normalize,
                                   with_memory_cache, with_file_cache, with_context,
                                   stop_exhausted):
    nnabla_config.set('DATA_ITERATOR', 'data_source_file_cache_size', '3')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_max_size', '10000')
    nnabla_config.set('DATA_ITERATOR', 'data_source_buffer_num_of_data', '9')

    if size == 10:
        csvfilename = test_data_csv_png_10
    elif size == 20:
        csvfilename = test_data_csv_png_20

    logger.info(csvfilename)

    main_thread = threading.current_thread().ident
    expect_epoch = [0]

    def end_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for end epoch check"
        else:
            print(f"E: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread().ident == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        if batch_size // size == 0:
            assert epoch == expect_epoch[0], "Failed for begin epoch check"
        else:
            print(f"B: {epoch} <--> {expect_epoch[0]}")
        assert threading.current_thread().ident == main_thread, "Failed for thread checking"

    if with_context:
        with data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted) as di:
            di.register_epoch_end_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
            check_data_iterator_result(di, batch_size, shuffle, normalize,
                                       stop_exhausted, expect_epoch)
    else:
        di = data_iterator_csv_dataset(uri=csvfilename,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       normalize=normalize,
                                       with_memory_cache=with_memory_cache,
                                       with_file_cache=with_file_cache,
                                       use_thread=use_thread,
                                       stop_exhausted=stop_exhausted)
        di.register_epoch_end_callback(begin_epoch)
        di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, normalize,
                                   stop_exhausted, expect_epoch)
        di.close()