def test_simple_int(self):
    """Check indexed reads after two consecutive writer sessions on one path.

    Each session appends a run of 'train' records followed by a run of 'val'
    records, every record being the ASCII form of its position inside the
    run.  The reader is then expected to expose, per index key, the first
    session's run immediately followed by the second session's run.
    """
    def train_key():
        # A fresh dict per call, exactly as the literal-based original did.
        return {'subset': 'train', 'subtask': 'domain'}

    def val_key():
        return {'subset': 'val', 'subtask': 'domain'}

    # (train records, val records) appended by each writer session.
    sessions = ((1000, 500), (100, 100))

    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')

        for train_count, val_count in sessions:
            writer = Writer(path, compression=None)
            for value in range(train_count):
                writer.append_record(str(value).encode(), index=train_key())
            for value in range(val_count):
                writer.append_record(str(value).encode(), index=val_key())
            writer.close()

        reader = Reader(path)
        # Train positions are verified first, then val positions.
        for make_key, column in ((train_key, 0), (val_key, 1)):
            expected = [
                value
                for session in sessions
                for value in range(session[column])
            ]
            for position, value in enumerate(expected):
                assert value == int(reader.get(position, make_key()))
        reader.close()
def open(self, artifact_name: str, *args, **kwargs):
    """Create the handler for ``artifact_name`` and register it.

    The handler is built on the path returned by
    ``self._get_artifact_path(artifact_name)`` and receives ``*args`` and
    ``**kwargs`` unchanged: a ``Writer`` when ``self._mode`` equals
    ``self.WRITING_MODE``, a ``ReaderIterator`` when it equals
    ``self.READING_MODE``.  The instance is stored in ``self._artifacts``
    under ``artifact_name``; nothing is returned.

    Raises:
        AssertionError: if ``artifact_name`` is already registered.
        ValueError: if ``self._mode`` matches neither supported mode.
    """
    assert artifact_name not in self._artifacts

    artifact_path = self._get_artifact_path(artifact_name)

    if self._mode == self.WRITING_MODE:
        artifact_instance = Writer(artifact_path, *args, **kwargs)
    elif self._mode == self.READING_MODE:
        artifact_instance = ReaderIterator(artifact_path, *args, **kwargs)
    else:
        # Without this branch the registration below would fail with an
        # UnboundLocalError that hides the real cause.
        raise ValueError(
            'unsupported mode {!r} for artifact {!r}'.format(
                self._mode, artifact_name))

    self._artifacts[artifact_name] = artifact_instance
def test_read_write_on_existing_data(self):
    """Append to an existing gzip storage and read the flushed records.

    A first writer stores ``length`` records and is closed.  A second writer
    opened with ``rewrite=False`` adds ``length`` more and is only flushed.
    A reader created with ``uncommitted_bucket_visible=True`` is then
    expected to report all ``2 * length`` records and return the last one.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 1000

        writer = Writer(path, compression='gzip')
        for value in range(length):
            writer.append_record(str(value).encode())
        writer.close()

        writer = Writer(path, rewrite=False, compression='gzip')
        for value in range(length, 2 * length):
            writer.append_record(str(value).encode())
        writer.flush()

        # Position (and value) of the final record appended above.
        last = 2 * length - 1

        reader = Reader(path, uncommitted_bucket_visible=True)
        assert reader.get_records_num() == last + 1
        assert last == int(reader.get(last).decode())

        # The reader is closed while the second writer is still open.
        reader.close()
        writer.close()
def test_modification_time(self):
    """Check that a second write session changes the reported modification time.

    The same 100 records are written twice through separate writers; after
    each session a reader is opened and ``get_modification_time()`` is
    sampled.  The two samples must differ.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 100
        mod_times = []

        for _ in range(2):
            writer = Writer(path, compression=None)
            for index in range(length):
                writer.append_record(str(index).encode())
            writer.close()

            reader = Reader(path)
            try:
                mod_times.append(reader.get_modification_time())
            finally:
                # Close every reader, as the sibling tests do, so no handle
                # is left open when the temporary directory is removed.
                reader.close()

        first_mod_time, second_mod_time = mod_times
        assert first_mod_time != second_mod_time
def test_write_mode_binary(self):
    """Check that a writer opened with ``rewrite=True`` replaces earlier records.

    A first session stores ``length // 2`` placeholder records (``b'0'``).
    A second session, also opened with ``rewrite=True``, stores the encoded
    numbers ``length // 2 .. length - 1``.  The reader must then report only
    the second session's records, starting at position 0.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 1000
        half = length // 2

        writer = Writer(path, rewrite=True)
        for _ in range(half):
            writer.append_record(b'0')
        writer.close()

        writer = Writer(path, rewrite=True)
        for index in range(half, length):
            writer.append_record(str(index).encode())
        writer.close()

        reader = Reader(path)
        try:
            assert reader.get_records_num() == half
            for index in range(half, length):
                entry = str(index).encode()
                assert entry == reader.get(index - half)
        finally:
            # Close the reader even when an assertion fails, as the sibling
            # tests do, so no handle outlives the temporary directory.
            reader.close()
def test_gzip_compression(self):
    """Round-trip 1000 records through a writer created with ``compression='gzip'``.

    Every record is the encoded form of its position; the reader must report
    the full count and return each record at its position.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 1000

        writer = Writer(path, compression='gzip')
        for index in range(length):
            writer.append_record(str(index).encode())
        writer.close()

        reader = Reader(path)
        try:
            assert reader.get_records_num() == length
            for index in range(length):
                assert index == int(reader.get(index).decode())
        finally:
            # Close the reader even when an assertion fails, as the sibling
            # tests do, so no handle outlives the temporary directory.
            reader.close()
def test_read_write(self):
    """Check that flushed but unclosed records are hidden from a default-visibility reader.

    1000 records are appended and flushed while the writer stays open.  A
    reader created with ``uncommitted_bucket_visible=False`` is expected to
    report zero records at that point.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 1000

        writer = Writer(path, compression=None)
        for value in range(length):
            payload = str(value).encode()
            writer.append_record(payload)
        writer.flush()

        reader = Reader(path, uncommitted_bucket_visible=False)
        visible_records = reader.get_records_num()
        assert visible_records == 0

        # The reader is closed first; the writer is closed only afterwards.
        reader.close()
        writer.close()
def test_simple_binary(self):
    """Round-trip 5000 variable-length binary records without compression.

    Record ``i`` is the decimal text of ``i`` repeated ``i`` times, so sizes
    range from an empty record (``i == 0``) upwards.  The reader must report
    the full count and return every record unchanged.
    """
    def make_entry(index):
        # Payload used both when writing and when checking.
        return (str(index) * index).encode()

    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 5000

        writer = Writer(path, compression=None)
        for index in range(length):
            writer.append_record(make_entry(index))
        writer.close()

        reader = Reader(path)
        try:
            assert reader.get_records_num() == length
            for index in range(length):
                assert make_entry(index) == reader.get(index)
        finally:
            # Close the reader even when an assertion fails, as the sibling
            # tests do, so no handle outlives the temporary directory.
            reader.close()
def test_iterator_int(self):
    """Check sliced iteration over indexed records with ``ReaderIterator``.

    One writer session appends four runs (train, val, train, val); each
    record is the ASCII form of its position inside its run.  After
    ``apply_index`` selects a key, slices of the reader are expected to
    yield that key's runs back to back.
    """
    # (subset, run length) in the order the runs are appended.
    runs = (('train', 1000), ('val', 500), ('train', 100), ('val', 100))

    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')

        writer = Writer(path, compression=None)
        for subset, count in runs:
            for value in range(count):
                writer.append_record(
                    str(value).encode(),
                    index={'subset': subset, 'subtask': 'domain'},
                )
        writer.close()

        reader = ReaderIterator(path)

        # Slices aligned with run boundaries restart counting at zero.
        aligned_slices = (
            ('train', ((0, 1000), (1000, 1100))),
            ('val', ((0, 500), (500, 600))),
        )
        for subset, bounds in aligned_slices:
            reader.apply_index({'subset': subset, 'subtask': 'domain'})
            for start, stop in bounds:
                for offset, record in enumerate(reader[start:stop]):
                    assert offset == int(record)

        # With the 'val' index still applied, one slice spans both val runs.
        for offset, record in enumerate(reader[0:600]):
            expected = offset if offset < 500 else offset - 500
            assert expected == int(record)

        reader.close()
def test_append_mode_binary(self):
    """Check that writers opened with ``rewrite=False`` extend existing data.

    1000 records are written in 5 equal chunks, each through its own writer
    session.  The reader must then report all records and return the encoded
    number ``i`` at position ``i``.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        path = os.path.join(temp_dir, 'loss')
        length = 1000
        chunks = 5
        chunk_len = length // chunks

        for chunk in range(chunks):
            start = chunk * chunk_len
            writer = Writer(path, rewrite=False)
            for index in range(start, start + chunk_len):
                writer.append_record(str(index).encode())
            writer.close()

        reader = Reader(path)
        try:
            assert reader.get_records_num() == length
            for index in range(length):
                entry = str(index).encode()
                assert entry == reader.get(index)
        finally:
            # Close the reader even when an assertion fails, as the sibling
            # tests do, so no handle outlives the temporary directory.
            reader.close()