def test_temp_file_removed_on_error(self):
    cache_filename = self.TEST_DIR / "cache_file"
    with pytest.raises(IOError, match="I made this up"):
        with CacheFile(cache_filename) as handle:
            raise IOError("I made this up")
    # Neither the temp file nor the target cache file should survive the error.
    assert not os.path.exists(handle.name)
    assert not os.path.exists(cache_filename)
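# For context, here is a minimal sketch of the contract the test above relies
# on. This is an assumption about how `CacheFile` behaves (write to a temp
# file, rename into place only on a clean exit, delete the temp file on
# error), not its actual implementation; `SketchCacheFile` is a hypothetical
# stand-in.

import os
import tempfile


class SketchCacheFile:
    """Hypothetical stand-in illustrating CacheFile's rename-on-success contract."""

    def __init__(self, final_filename, mode="w+b", suffix=".tmp"):
        self.final_filename = str(final_filename)
        # delete=False: we handle renaming/cleanup ourselves in __exit__.
        self.temp_file = tempfile.NamedTemporaryFile(mode, suffix=suffix, delete=False)

    def __enter__(self):
        return self.temp_file

    def __exit__(self, exc_type, exc_value, traceback):
        self.temp_file.close()
        if exc_type is None:
            # Clean exit: atomically move the finished file into place.
            os.replace(self.temp_file.name, self.final_filename)
        else:
            # Error: remove the temp file so no partial output survives.
            os.remove(self.temp_file.name)
        return False  # never swallow the exception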
def build_vocab_from_args(args: argparse.Namespace):
    if not args.output_path.endswith(".tar.gz"):
        raise ValueError("param 'output_path' should end with '.tar.gz'")
    if os.path.exists(args.output_path) and not args.force:
        raise RuntimeError(f"{args.output_path} already exists. Use --force to overwrite.")

    output_directory = os.path.dirname(args.output_path)
    if output_directory:
        # `dirname` is empty when `output_path` is a bare filename, and
        # `os.makedirs("")` raises, so only create the directory when there is one.
        os.makedirs(output_directory, exist_ok=True)

    params = Params.from_file(args.param_path)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Serializes the vocab to 'tempdir/vocabulary'.
        make_vocab_from_params(params, temp_dir)

        # The CacheFile context manager gives us a temporary file to write to.
        # On a successful exit from the context, it will rename the temp file to
        # the target `output_path`.
        with CacheFile(args.output_path, suffix=".tar.gz") as temp_archive:
            logger.info("Archiving vocabulary to %s", args.output_path)

            with tarfile.open(temp_archive.name, "w:gz") as archive:
                vocab_dir = os.path.join(temp_dir, "vocabulary")
                for fname in os.listdir(vocab_dir):
                    if fname.endswith(".lock"):
                        # Skip lock files left over from vocab serialization.
                        continue
                    archive.add(os.path.join(vocab_dir, fname), arcname=fname)

    print(f"Success! Vocab saved to {args.output_path}")
    print('You can now set the "vocabulary" entry of your training config to:')
    print(json.dumps({"type": "from_files", "directory": os.path.abspath(args.output_path)}))
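# A hedged usage sketch: one way to wire `build_vocab_from_args` into a CLI.
# The argument names (`param_path`, `output_path`, `--force`) mirror the
# attributes the function reads; the parser below is illustrative, not
# necessarily the project's real entry point.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Archive a vocabulary as a .tar.gz file.")
    parser.add_argument("param_path", help="path to the experiment configuration file")
    parser.add_argument("output_path", help="where to write the archive; must end in .tar.gz")
    parser.add_argument("--force", action="store_true", help="overwrite output_path if it exists")
    build_vocab_from_args(parser.parse_args())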
def _instances_to_cache_file(self, cache_filename, instances) -> None:
    # We serialize to a temp file first in case anything goes wrong while
    # writing to cache (e.g., the computer shuts down unexpectedly).
    # Then we just copy the file over to `cache_filename`.
    with CacheFile(cache_filename, mode="w+") as cache_handle:
        logger.info("Caching instances to temp file %s", cache_handle.name)
        for instance in Tqdm.tqdm(instances, desc="caching instances"):
            cache_handle.write(self.serialize_instance(instance) + "\n")
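# The cache format implied here is one serialized instance per line. Below is
# a minimal, hypothetical serializer pair that satisfies that contract; the
# real `serialize_instance` / `deserialize_instance` may use a different
# encoding, so JSON is just an assumption for illustration.

import json


def serialize_instance_sketch(fields: dict) -> str:
    # json.dumps never emits raw newlines, which the line-oriented cache requires.
    return json.dumps(fields)


def deserialize_instance_sketch(line: str) -> dict:
    return json.loads(line)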
def _instance_iterator(self, file_path: str) -> Iterable[Instance]:
    cache_file: Optional[str] = None
    if self._cache_directory:
        cache_file = self._get_cache_location_for_file_path(file_path)

    if cache_file is not None and os.path.exists(cache_file):
        cache_file_lock = FileLock(cache_file + ".lock", timeout=self.CACHE_FILE_LOCK_TIMEOUT)
        try:
            cache_file_lock.acquire()
            # We make an assumption here that if we can obtain the lock, no one will
            # be trying to write to the file anymore, so it should be safe to release
            # the lock before reading so that other processes can also read from it.
            cache_file_lock.release()
            logger.info("Reading instances from cache %s", cache_file)
            with open(cache_file) as data_file:
                yield from self._multi_worker_islice(
                    data_file, transform=self.deserialize_instance
                )
        except Timeout:
            logger.warning(
                "Failed to acquire lock on dataset cache file within %d seconds. "
                "Cannot use cache to read instances.",
                self.CACHE_FILE_LOCK_TIMEOUT,
            )
            yield from self._multi_worker_islice(self._read(file_path), ensure_lazy=True)
    elif cache_file is not None and not os.path.exists(cache_file):
        instances = self._multi_worker_islice(self._read(file_path), ensure_lazy=True)
        # The cache file doesn't exist so we'll try writing to it.
        if self.max_instances is not None:
            # But we don't write to the cache when max_instances is specified.
            logger.warning("Skipping writing to data cache since max_instances was specified.")
            yield from instances
        elif util.is_distributed() or (get_worker_info() and get_worker_info().num_workers):
            # We also shouldn't write to the cache if there's more than one process loading
            # instances since each worker only receives a partial share of the instances.
            logger.warning(
                "Can't cache data instances when there are multiple processes loading data."
            )
            yield from instances
        else:
            try:
                with FileLock(cache_file + ".lock", timeout=self.CACHE_FILE_LOCK_TIMEOUT):
                    with CacheFile(cache_file, mode="w+") as cache_handle:
                        logger.info("Caching instances to temp file %s", cache_handle.name)
                        for instance in instances:
                            cache_handle.write(self.serialize_instance(instance) + "\n")
                            yield instance
            except Timeout:
                logger.warning(
                    "Failed to acquire lock on dataset cache file within %d seconds. "
                    "Cannot write to cache.",
                    self.CACHE_FILE_LOCK_TIMEOUT,
                )
                yield from instances
    else:
        # No cache.
        yield from self._multi_worker_islice(self._read(file_path), ensure_lazy=True)
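# For intuition, a hedged sketch of the sharding a `_multi_worker_islice`-style
# helper could perform: give each of N workers every N-th element of the
# stream, matching the comment above that each worker "only receives a partial
# share of the instances". This is an illustrative assumption, not the actual
# helper.

import itertools


def multi_worker_islice_sketch(iterable, worker_id, num_workers, transform=None):
    """Yield this worker's strided share of `iterable`, optionally transformed."""
    # Worker k takes elements k, k + N, k + 2N, ... of the stream.
    sliced = itertools.islice(iterable, worker_id, None, num_workers)
    return map(transform, sliced) if transform is not None else sliced


# e.g. worker 1 of 3 over the stream 0..8 sees elements 1, 4, 7:
# list(multi_worker_islice_sketch(range(9), worker_id=1, num_workers=3)) == [1, 4, 7]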