def _read(self, file_path):
    # if `file_path` is a URL, redirect to the cache
    file_path = cached_path(file_path)
    logger.info("Reading instances from lines in file at: %s", file_path)
    for amr in AMRIO.read(file_path):
        yield self.text_to_instance(amr)
    self.report_coverage()
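For reference, the URL-to-cache redirection above is done by `cached_path`; a minimal sketch of that behavior, assuming the AllenNLP-style `allennlp.common.file_utils` module and a hypothetical URL:

from allennlp.common.file_utils import cached_path  # assumed import path

# A remote file is downloaded once and served from the local cache on later
# calls; a plain local path is returned unchanged.
local_path = cached_path("https://example.com/data/train.amr")  # hypothetical URL
with open(local_path) as f:
    first_line = f.readline()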
Example #2
    def __init__(self,
                 file_uri: str,
                 encoding: str = DEFAULT_ENCODING,
                 cache_dir: str = None) -> None:

        self.uri = file_uri
        self._encoding = encoding
        self._cache_dir = cache_dir
        self._archive_handle: Any = None  # only if the file is inside an archive

        main_file_uri, path_inside_archive = parse_embeddings_file_uri(
            file_uri)
        main_file_local_path = cached_path(main_file_uri, cache_dir=cache_dir)

        if zipfile.is_zipfile(main_file_local_path):  # ZIP archive
            self._open_inside_zip(main_file_uri, path_inside_archive)

        elif tarfile.is_tarfile(main_file_local_path):  # TAR archive
            self._open_inside_tar(main_file_uri, path_inside_archive)

        else:  # all the other supported formats, including uncompressed files
            if path_inside_archive:
                raise ValueError('Unsupported archive format: %s' %
                                 main_file_uri)

            # All the Python packages for compressed files share the same interface as io.open
            extension = get_file_extension(main_file_uri)
            package = {
                '.txt': io,
                '.vec': io,
                '.gz': gzip,
                '.bz2': bz2,
                '.lzma': lzma,
            }.get(extension, None)

            if package is None:
                logger.warning(
                    'The embeddings file has an unknown file extension "%s". '
                    'We will assume the file is an (uncompressed) text file',
                    extension)
                package = io

            self._handle = package.open(main_file_local_path,
                                        'rt',
                                        encoding=encoding)  # type: ignore

        # To use this with tqdm we'd like to know the number of tokens. The first line of the
        # embeddings file may contain this count: if it does, we want to start iterating from
        # the 2nd line, otherwise from the 1st. Unfortunately, once we have read the first line
        # we cannot rewind the file iterator, because the underlying file may not be seekable;
        # we use itertools.chain instead.
        first_line = next(self._handle)  # this moves the iterator forward
        self.num_tokens = EmbeddingsTextFile._get_num_tokens_from_first_line(
            first_line)
        if self.num_tokens:
            # the first line is a header line: start iterating from the 2nd line
            self._iterator = self._handle
        else:
            # the first line is not a header line: start iterating from the 1st line
            self._iterator = itertools.chain([first_line], self._handle)
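The `itertools.chain` trick above is worth isolating. A self-contained sketch of peeking at the first line of a stream that may not support seek(), with a deliberately simplified header check:

import io
import itertools

stream = io.StringIO("400000 300\nthe 0.1 0.2\n")  # hypothetical embeddings text

first_line = next(stream)                 # consumes the line; no seek() needed
if all(t.isdigit() for t in first_line.split()):      # simplified header check
    iterator = stream                     # header line: skip it
else:
    iterator = itertools.chain([first_line], stream)  # put the line back

for line in iterator:
    print(line, end="")                   # prints only the data lines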
Example #3
def load_archive(archive_file: str,
                 device=None,
                 weights_file: str = None) -> Archive:
    """
    Instantiates an Archive from an archived `tar.gz` file.
    Parameters
    ----------
    archive_file: ``str``
        The archive file to load the model from.
    weights_file: ``str``, optional (default = None)
        The weights file to use.  If unspecified, weights.th in the archive_file will be used.
    device: ``None`` or PyTorch device object.
    """
    # redirect to the cache, if necessary
    resolved_archive_file = cached_path(archive_file)

    if resolved_archive_file == archive_file:
        logger.info(f"loading archive file {archive_file}")
    else:
        logger.info(
            f"loading archive file {archive_file} from cache at {resolved_archive_file}"
        )

    tempdir = None
    if os.path.isdir(resolved_archive_file):
        serialization_dir = resolved_archive_file
    else:
        # Extract archive to temp dir
        tempdir = tempfile.mkdtemp()
        logger.info(
            f"extracting archive file {resolved_archive_file} to temp dir {tempdir}"
        )
        with tarfile.open(resolved_archive_file, 'r:gz') as archive:
            archive.extractall(tempdir)

        serialization_dir = tempdir

    # Load config
    config = Params.from_file(os.path.join(serialization_dir, CONFIG_NAME))
    config.loading_from_archive = True

    if weights_file:
        weights_path = weights_file
    else:
        weights_path = os.path.join(serialization_dir, _WEIGHTS_NAME)

    # Instantiate model. Use a duplicate of the config, as it will get consumed.
    model = Model.load(config,
                       weights_file=weights_path,
                       serialization_dir=serialization_dir,
                       device=device)

    if tempdir:
        # Clean up temp dir
        shutil.rmtree(tempdir)

    return Archive(model=model, config=config)
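A hypothetical usage sketch, grounded only in the fields returned above and assuming (as in AllenNLP) that Model subclasses torch.nn.Module:

archive = load_archive("model.tar.gz")      # hypothetical path; a URL or an
                                            # already-extracted directory also works
model, config = archive.model, archive.config
model.eval()                                # switch to inference mode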
Example #4
def _open_inside_zip(self,
                     archive_path: str,
                     member_path: Optional[str] = None) -> None:
    cached_archive_path = cached_path(archive_path,
                                      cache_dir=self._cache_dir)
    archive = zipfile.ZipFile(cached_archive_path, 'r')
    if member_path is None:
        members_list = archive.namelist()
        member_path = self._get_the_only_file_in_the_archive(
            members_list, archive_path)
    member_path = cast(str, member_path)
    member_file = archive.open(member_path, 'r')
    self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
    self._archive_handle = archive
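The zip branch reduces to a standard-library pattern; a minimal standalone sketch with hypothetical archive and member names:

import io
import zipfile

# ZipFile.open returns a binary file object; TextIOWrapper decodes it lazily,
# so the member is streamed rather than extracted to disk first.
with zipfile.ZipFile("embeddings.zip") as archive:    # hypothetical archive
    with archive.open("glove.txt") as member:         # hypothetical member
        text = io.TextIOWrapper(member, encoding="utf-8")
        first_line = next(text)

The tar case in the next example is analogous, with `tarfile.extractfile` playing the role of `ZipFile.open`.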
Example #5
def _open_inside_tar(self,
                     archive_path: str,
                     member_path: Optional[str] = None) -> None:
    cached_archive_path = cached_path(archive_path,
                                      cache_dir=self._cache_dir)
    archive = tarfile.open(cached_archive_path, 'r')
    if member_path is None:
        members_list = archive.getnames()
        member_path = self._get_the_only_file_in_the_archive(
            members_list, archive_path)
    member_path = cast(str, member_path)
    member = archive.getmember(member_path)  # raises KeyError if not present
    member_file = cast(IO[bytes], archive.extractfile(member))
    self._handle = io.TextIOWrapper(member_file, encoding=self._encoding)
    self._archive_handle = archive
Example #6
    def __init__(self,
                 encoder: Dict[str, int] = None,
                 byte_pairs: List[Tuple[str, str]] = None,
                 n_ctx: int = 512,
                 model_path: str = None,
                 namespace: str = 'openai_transformer',
                 tokens_to_add: List[str] = None) -> None:
        self._namespace = namespace
        self._added_to_vocabulary = False

        too_much_information = model_path and (encoder or byte_pairs)
        too_little_information = not model_path and not (encoder and byte_pairs)

        if too_much_information or too_little_information:
            raise ConfigurationError(
                "must specify either model path or (encoder + byte_pairs) but not both"
            )

        if model_path:
            model_path = cached_path(model_path)

            # Load encoder and byte_pairs from tar.gz
            with tarfile.open(model_path) as tmp:
                encoder_name = next(m.name for m in tmp.getmembers()
                                    if 'encoder_bpe' in m.name)
                encoder_info = tmp.extractfile(encoder_name)

                if encoder_info:
                    encoder = json.loads(encoder_info.read())
                else:
                    raise ConfigurationError(
                        f"expected encoder_bpe file in archive {model_path}")

                bpe_name = next(m.name for m in tmp.getmembers()
                                if m.name.endswith('.bpe'))
                bpe_info = tmp.extractfile(bpe_name)

                if bpe_info:
                    # First line is "version", last line is blank
                    lines = bpe_info.read().decode('utf-8').split('\n')[1:-1]
                    # Convert "b1 b2" -> (b1, b2)
                    byte_pairs = [tuple(line.split())
                                  for line in lines]  # type: ignore
                else:
                    raise ConfigurationError(
                        f"expected .bpe file in archive {model_path}")

        if tokens_to_add is not None:
            for token in tokens_to_add:
                encoder[token + '</w>'] = len(encoder)
            self.tokens_to_add = set(tokens_to_add)
        else:
            self.tokens_to_add = None

        self.encoder = encoder
        self.decoder = {
            word_id: word
            for word, word_id in self.encoder.items()
        }

        # Compute merge ranks: pairs earlier in the BPE merge list get lower
        # rank and are merged first.
        self.bpe_ranks = {pair: idx for idx, pair in enumerate(byte_pairs)}

        self.cache: Dict[str, List[str]] = {}
        self.n_ctx = n_ctx
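For intuition about how `bpe_ranks` is used, here is a minimal sketch of the standard greedy BPE merge loop (illustrative only; not this class's actual tokenization method):

def bpe_merge(symbols, bpe_ranks):
    # Repeatedly merge the adjacent pair with the lowest rank (i.e. the pair
    # learned earliest) until no known pair remains.
    symbols = list(symbols)
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if best not in bpe_ranks:
            break  # no mergeable pair left
        merged, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols

print(bpe_merge("low", {('l', 'o'): 0, ('lo', 'w'): 1}))  # ['low']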