Example #1
def test_temp_file_removed_on_error(self):
    cache_filename = self.TEST_DIR / "cache_file"
    # If an exception is raised inside the context, CacheFile should clean up
    # its temporary file and never create the target `cache_filename`.
    with pytest.raises(IOError, match="I made this up"):
        with CacheFile(cache_filename) as handle:
            raise IOError("I made this up")
    assert not os.path.exists(handle.name)
    assert not os.path.exists(cache_filename)
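The test above depends on two guarantees from the CacheFile context manager: on an exception, the temporary file is removed and the target path is never created; on a clean exit (as in the later examples), the temporary file is renamed to the target. The sketch below only illustrates that pattern; it is not AllenNLP's actual CacheFile implementation, and the name SimpleCacheFile is made up for this illustration.

import os
import tempfile


class SimpleCacheFile:
    """Write to a temporary file and rename it to `target` only on a clean exit."""

    def __init__(self, target, mode="w+b", suffix=""):
        self.target = str(target)
        # Creating the temp file in the target's directory keeps the final
        # os.replace() on a single filesystem.
        self.handle = tempfile.NamedTemporaryFile(
            mode, suffix=suffix, dir=os.path.dirname(self.target) or None, delete=False
        )
        self.name = self.handle.name  # the `.name` attribute used in the examples

    def __enter__(self):
        return self.handle

    def __exit__(self, exc_type, exc_value, traceback):
        self.handle.close()
        if exc_type is None:
            # Success: move the finished temp file into place.
            os.replace(self.name, self.target)
        else:
            # Failure: remove the temp file and leave the target untouched.
            os.remove(self.name)
        return False  # never swallow the exception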
Example #2
def build_vocab_from_args(args: argparse.Namespace):
    if not args.output_path.endswith(".tar.gz"):
        raise ValueError("param 'output_path' should end with '.tar.gz'")

    if os.path.exists(args.output_path) and not args.force:
        raise RuntimeError(f"{args.output_path} already exists. Use --force to overwrite.")

    output_directory = os.path.dirname(args.output_path)
    os.makedirs(output_directory, exist_ok=True)

    params = Params.from_file(args.param_path)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Serializes the vocab to 'tempdir/vocabulary'.
        make_vocab_from_params(params, temp_dir)

        # The CacheFile context manager gives us a temporary file to write to.
        # On a successful exit from the context, it will rename the temp file to
        # the target `output_path`.
        with CacheFile(args.output_path, suffix=".tar.gz") as temp_archive:
            logger.info("Archiving vocabulary to %s", args.output_path)

            with tarfile.open(temp_archive.name, "w:gz") as archive:
                vocab_dir = os.path.join(temp_dir, "vocabulary")
                for fname in os.listdir(vocab_dir):
                    if fname.endswith(".lock"):
                        continue
                    archive.add(os.path.join(vocab_dir, fname), arcname=fname)

    print(f"Success! Vocab saved to {args.output_path}")
    print('You can now set the "vocabulary" entry of your training config to:')
    print(json.dumps({"type": "from_files", "directory": os.path.abspath(args.output_path)}))
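A function like this is normally driven by an argparse parser. The wiring below is a plausible sketch for trying it out; the argument names and the example paths are assumptions for illustration, not the library's actual CLI definition.

import argparse

parser = argparse.ArgumentParser(description="Archive a vocabulary built from a training config.")
parser.add_argument("param_path", help="path to the training configuration file")
parser.add_argument("output_path", help="where to write the archive; must end with '.tar.gz'")
parser.add_argument("--force", action="store_true", help="overwrite an existing archive")

# Hypothetical invocation; the paths here are placeholders.
args = parser.parse_args(["config.jsonnet", "vocab/vocabulary.tar.gz", "--force"])
build_vocab_from_args(args)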
Example #3
def _instances_to_cache_file(self, cache_filename, instances) -> None:
    # We serialize to a temp file first in case anything goes wrong while
    # writing to cache (e.g., the computer shuts down unexpectedly).
    # On a clean exit, the CacheFile context manager renames the temp file
    # to `cache_filename`.
    with CacheFile(cache_filename, mode="w+") as cache_handle:
        logger.info("Caching instances to temp file %s", cache_handle.name)
        for instance in Tqdm.tqdm(instances, desc="caching instances"):
            cache_handle.write(self.serialize_instance(instance) + "\n")
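The read side of this cache is just the inverse loop: each line of the finished file holds one serialized instance, which Example #4 reads back through deserialize_instance. A minimal sketch, with a made-up helper name and `reader` standing in for the dataset reader instance:

def instances_from_cache_file(reader, cache_filename):
    # Each line in the cache holds one serialized instance, so reading it back
    # is the inverse of the write loop above.
    with open(cache_filename) as cache_handle:
        for line in cache_handle:
            yield reader.deserialize_instance(line.strip())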
Example #4
    def _instance_iterator(self, file_path: str) -> Iterable[Instance]:
        cache_file: Optional[str] = None
        if self._cache_directory:
            cache_file = self._get_cache_location_for_file_path(file_path)

        if cache_file is not None and os.path.exists(cache_file):
            cache_file_lock = FileLock(cache_file + ".lock",
                                       timeout=self.CACHE_FILE_LOCK_TIMEOUT)
            try:
                cache_file_lock.acquire()
                # We make an assumption here that if we can obtain the lock, no one will
                # be trying to write to the file anymore, so it should be safe to release the lock
                # before reading so that other processes can also read from it.
                cache_file_lock.release()
                logger.info("Reading instances from cache %s", cache_file)
                with open(cache_file) as data_file:
                    yield from self._multi_worker_islice(
                        data_file, transform=self.deserialize_instance)
            except Timeout:
                logger.warning(
                    "Failed to acquire lock on dataset cache file within %d seconds. "
                    "Cannot use cache to read instances.",
                    self.CACHE_FILE_LOCK_TIMEOUT,
                )
                yield from self._multi_worker_islice(self._read(file_path),
                                                     ensure_lazy=True)
        elif cache_file is not None and not os.path.exists(cache_file):
            instances = self._multi_worker_islice(self._read(file_path),
                                                  ensure_lazy=True)
            # The cache file doesn't exist so we'll try writing to it.
            if self.max_instances is not None:
                # But we don't write to the cache when max_instances is specified.
                logger.warning(
                    "Skipping writing to data cache since max_instances was specified."
                )
                yield from instances
            elif util.is_distributed() or (get_worker_info()
                                           and get_worker_info().num_workers):
                # We also shouldn't write to the cache if there's more than one process loading
                # instances since each worker only receives a partial share of the instances.
                logger.warning(
                    "Can't cache data instances when there are multiple processes loading data"
                )
                yield from instances
            else:
                try:
                    with FileLock(cache_file + ".lock",
                                  timeout=self.CACHE_FILE_LOCK_TIMEOUT):
                        with CacheFile(cache_file, mode="w+") as cache_handle:
                            logger.info("Caching instances to temp file %s",
                                        cache_handle.name)
                            for instance in instances:
                                cache_handle.write(
                                    self.serialize_instance(instance) + "\n")
                                yield instance
                except Timeout:
                    logger.warning(
                        "Failed to acquire lock on dataset cache file within %d seconds. "
                        "Cannot write to cache.",
                        self.CACHE_FILE_LOCK_TIMEOUT,
                    )
                    yield from instances
        else:
            # No cache.
            yield from self._multi_worker_islice(self._read(file_path),
                                                 ensure_lazy=True)
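The locking in this example uses the filelock package: a FileLock constructed with a timeout raises Timeout when the lock cannot be acquired in time, and it can be driven either with explicit acquire()/release() calls (the read branch releases early so other readers are not blocked) or as a context manager that holds the lock for the whole write. A condensed sketch of both styles, with a placeholder lock path:

from filelock import FileLock, Timeout

# Explicit style: acquire, then release early so other readers are not blocked.
try:
    lock = FileLock("/tmp/some_cache_file.lock", timeout=10)
    lock.acquire()
    lock.release()
    # ... read the cache file here ...
except Timeout:
    pass  # fall back to reading the data without the cache

# Context-manager style: the lock is held for the whole write and released on exit.
try:
    with FileLock("/tmp/some_cache_file.lock", timeout=10):
        pass  # ... write the cache file here ...
except Timeout:
    pass  # fall back to yielding instances without caching them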