def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that SimpleDataSource delivers every loaded sample exactly once.

    Images referenced by the CSV fixture are loaded eagerly and served
    through a ``SimpleDataSource``. The test expects each image's first
    element ``data[0][0][0]`` to equal its label, so the label also identifies
    the sample when checking the delivery order.

    Args:
        test_data_csv_png_20: Path to a CSV fixture whose rows are
            ``<image file>,<label>``; image paths are joined onto the CSV's
            directory.
        shuffle: Whether the data source should shuffle sample order.
    """
    csv_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: any row whose first column is not an
            # actual file is skipped. This covers e.g. a header row, and also
            # a blank line, which would otherwise resolve to csv_dir itself
            # and make open() fail.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))
    # Without this guard an empty dataset would pass the unshuffled check
    # vacuously.
    assert src_data, f"no images loaded from {test_data_csv_png_20}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        assert data[0][0][0] == label
        order.append(label)
    expected = list(range(size))
    if shuffle:
        # Shuffled: order differs from 0..size-1 but is still a permutation.
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
def test_simple_data_source(test_data_csv_png_20, shuffle):
    """Check that SimpleDataSource delivers every loaded sample exactly once.

    Images referenced by the CSV fixture are loaded eagerly and served
    through a ``SimpleDataSource``. The test expects each image's first
    element ``data[0][0][0]`` to equal its label, so the label also identifies
    the sample when checking the delivery order.

    Args:
        test_data_csv_png_20: Path to a CSV fixture whose rows are
            ``<image file>,<label>``; image paths are joined onto the CSV's
            directory.
        shuffle: Whether the data source should shuffle sample order.
    """
    csv_dir = os.path.dirname(test_data_csv_png_20)
    src_data = []
    with open(test_data_csv_png_20) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: any row whose first column is not an
            # actual file is skipped. This covers e.g. a header row, and also
            # a blank line, which would otherwise resolve to csv_dir itself
            # and make open() fail.
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, int(values[1])))
    # Without this guard an empty dataset would pass the unshuffled check
    # vacuously.
    assert src_data, f"no images loaded from {test_data_csv_png_20}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    ds = SimpleDataSource(test_load_func, size, shuffle=shuffle)
    order = []
    for _ in range(ds.size):
        data, label = ds.next()
        assert data[0][0][0] == label
        order.append(label)
    expected = list(range(size))
    if shuffle:
        # Shuffled: order differs from 0..size-1 but is still a permutation.
        assert order != expected
        assert sorted(order) == expected
    else:
        assert order == expected
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle):
    """Exercise ``data_iterator_simple`` over images listed in a CSV fixture.

    Each loaded sample is ``(image, [label])``. The checks themselves live in
    ``check_data_iterator_result``, which is defined outside this view.

    Args:
        test_data_csv_png_10: Path to a CSV fixture whose rows are
            ``<image file>,<label>``; image paths are joined onto the CSV's
            directory.
        batch_size: Batch size passed to the iterator.
        shuffle: Whether the iterator should shuffle sample order.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: rows whose first column is not an
            # actual file are skipped (e.g. a header row, or a blank line,
            # which would otherwise resolve to csv_dir and break open()).
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail loudly rather than testing an empty iterator.
    assert src_data, f"no images loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle) as di:
        check_data_iterator_result(di, batch_size, shuffle, False)
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Exercise ``data_iterator_simple`` over images listed in a CSV fixture.

    Each loaded sample is ``(image, [label])``. The checks themselves live in
    ``check_data_iterator_result``, which is defined outside this view.

    Args:
        test_data_csv_png_10: Path to a CSV fixture whose rows are
            ``<image file>,<label>``; image paths are joined onto the CSV's
            directory.
        batch_size: Batch size passed to the iterator.
        shuffle: Whether the iterator should shuffle sample order.
        stop_exhausted: Forwarded to both the iterator and the result
            checker.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: rows whose first column is not an
            # actual file are skipped (e.g. a header row, or a blank line,
            # which would otherwise resolve to csv_dir and break open()).
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail loudly rather than testing an empty iterator.
    assert src_data, f"no images loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    size = len(src_data)
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        check_data_iterator_result(
            di, batch_size, shuffle, False, stop_exhausted)
def test_data_iterator_simple(test_data_csv_png_10, batch_size, shuffle,
                              stop_exhausted):
    """Exercise ``data_iterator_simple`` including its epoch callbacks.

    Each loaded sample is ``(image, [label])``. Begin and end epoch callbacks
    are registered. Both assert that the epoch they receive equals
    ``expect_epoch[0]`` and that they run on the thread that started the
    test. Iteration checks are delegated to ``check_data_iterator_result``
    (defined outside this view).

    Args:
        test_data_csv_png_10: Path to a CSV fixture whose rows are
            ``<image file>,<label>``; image paths are joined onto the CSV's
            directory.
        batch_size: Batch size passed to the iterator.
        shuffle: Whether the iterator should shuffle sample order.
        stop_exhausted: Forwarded to both the iterator and the result
            checker.
    """
    csv_dir = os.path.dirname(test_data_csv_png_10)
    src_data = []
    with open(test_data_csv_png_10) as f:
        for line in f:
            values = [x.strip() for x in line.split(',')]
            img_file_name = os.path.join(csv_dir, values[0])
            # isfile rather than exists: rows whose first column is not an
            # actual file are skipped (e.g. a header row, or a blank line,
            # which would otherwise resolve to csv_dir and break open()).
            if os.path.isfile(img_file_name):
                with open(img_file_name, 'rb') as img_file:
                    d = load_image(img_file)
                    src_data.append((d, [int(values[1])]))
    # Fail loudly; an empty dataset would otherwise surface as a
    # ZeroDivisionError in `batch_size // size` below.
    assert src_data, f"no images loaded from {test_data_csv_png_10}"

    def test_load_func(position):
        return src_data[position]

    def end_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for end epoch check"
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    def begin_epoch(epoch):
        print(f"{epoch} == {expect_epoch[0]}")
        assert epoch == expect_epoch[0], "Failed for begin epoch check"
        assert threading.current_thread(
        ).ident == main_thread, "Failed for thread checking"

    size = len(src_data)
    main_thread = threading.current_thread().ident
    # One-element list so the callbacks see updates made elsewhere;
    # presumably check_data_iterator_result advances it -- verify there.
    expect_epoch = [0]
    with data_iterator_simple(test_load_func, size, batch_size,
                              shuffle=shuffle,
                              stop_exhausted=stop_exhausted) as di:
        # begin_epoch must be registered as an epoch-*begin* callback.
        # Registering it on the end hook left begin callbacks untested.
        if batch_size // size == 0:
            # batch_size < size: each callback is registered twice,
            # presumably to cover multiple callbacks per event -- confirm.
            di.register_epoch_begin_callback(begin_epoch)
            di.register_epoch_end_callback(end_epoch)
        di.register_epoch_begin_callback(begin_epoch)
        di.register_epoch_end_callback(end_epoch)
        check_data_iterator_result(di, batch_size, shuffle, False,
                                   stop_exhausted, expect_epoch)
def _get_data(self, _position):
    """Load the sample at ``_position`` and return ``(image, label)``.

    ``_position`` is first mapped through ``self._indexes``. The resulting
    index selects both the image path in ``self.paths`` and the label in
    ``self.labels``. An image whose leading dimension is 1 is repeated three
    times along axis 0. This presumably turns a channel-first grayscale
    image into a 3-channel one; TODO confirm load_image returns (C, H, W).
    """
    idx = self._indexes[_position]
    image = load_image(self.paths[idx])
    image = image.repeat(3, 0) if image.shape[0] == 1 else image
    label = self.labels[idx]
    return image, label