Example #1
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # Hack to allow storing a Namespace inside a DictConfig. This should be
        # removed once we update to a newer omegaconf version that supports
        # object flags, or once we migrate all existing models.
        from omegaconf import _utils

        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True

        state["cfg"] = OmegaConf.create(state["cfg"])

        _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
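A minimal usage sketch for the function above (the checkpoint path and the overridden field are assumptions, not from the source):

# Rank 0 loads the checkpoint; arg_overrides patches config fields before use.
state = load_checkpoint_to_cpu(
    "checkpoints/checkpoint_last.pt",             # hypothetical path
    arg_overrides={"data": "/path/to/data-bin"},  # hypothetical override
    load_on_all_ranks=False,                      # default: one reader per node
)
model_state = state["model"]                      # the parameter tensors
cfg = state.get("cfg") or state.get("args")       # new- or old-style config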
Example #2
def get_indexed_dataset_to_local(path):
    local_index_path = PathManager.get_local_path(index_file_path(path))
    local_data_path = PathManager.get_local_path(data_file_path(path))

    assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), \
           "PathManager.get_local_path does not return files with expected patterns: " \
           f"{local_index_path} and {local_data_path}"

    local_path = local_data_path[:-4]  # stripping suffix ".bin"
    assert local_path == local_index_path[:-4]  # stripping suffix ".idx"
    return local_path
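For context, a fairseq indexed dataset is a file pair sharing one prefix; a short sketch of the convention this helper relies on (the prefix is made up):

prefix = "data-bin/train.en-de.en"   # hypothetical dataset prefix
index_file = prefix + ".idx"         # metadata: element offsets and sizes
data_file = prefix + ".bin"          # the binarized token ids themselves
# get_indexed_dataset_to_local(prefix) caches both files locally and returns
# the shared local prefix so callers can re-derive both paths from it.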
Example #3
 def load_langs(cls, args, **kwargs):
     if args.lang_dict and args.langs:
         raise ValueError(
             "--langs and --lang-dict can not both be specified")
     if args.lang_dict is None and args.langs is None:
         logger.warning(
             "External language dictionary is not provided; "
             "use lang-pairs to infer the set of supported languages. "
             "The language ordering is not stable which might cause "
             "misalignment in pretraining and finetuning.")
         # infer the language set from lang_pairs as given
         langs = list({
             x
             for lang_pair in args.lang_pairs for x in lang_pair.split("-")
         })
         langs = sorted(langs)
         logger.info(f"inferred language list: {langs}")
     elif args.lang_dict:
         with open(PathManager.get_local_path(args.lang_dict),
                   "r",
                   encoding="utf-8") as f:
             langs = [
                 lang.strip() for lang in f.readlines() if lang.strip()
             ]
             logger.info(
                 f"loaded language list from {args.lang_dict} as they are ordered in file"
             )
     elif args.langs:
         langs = args.langs
         logger.info(
             f"parsed the language list as they are ordered in the option: {langs}"
         )
     return langs
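The inference branch is easy to check in isolation; a standalone sketch with made-up language pairs:

lang_pairs = ["en-de", "de-fr", "en-fr"]  # hypothetical --lang-pairs values
langs = sorted({lang for pair in lang_pairs for lang in pair.split("-")})
print(langs)  # ['de', 'en', 'fr'] -- sorted, so the ordering is reproducible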
Example #4
    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
        def merge_result(counter):
            for w, c in sorted(counter.items()):
                dict.add_symbol(w, c)

        local_file = PathManager.get_local_path(filename)
        offsets = find_offsets(local_file, num_workers)
        if num_workers > 1:
            chunks = zip(offsets, offsets[1:])
            pool = Pool(processes=num_workers)
            results = []
            for (start_offset, end_offset) in chunks:
                results.append(
                    pool.apply_async(
                        Dictionary._add_file_to_dictionary_single_worker,
                        (
                            local_file,
                            tokenize,
                            dict.eos_word,
                            start_offset,
                            end_offset,
                        ),
                    )
                )
            pool.close()
            pool.join()
            for r in results:
                merge_result(r.get())
        else:
            merge_result(
                Dictionary._add_file_to_dictionary_single_worker(
                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
                )
            )
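A hypothetical driver for this helper (assumes fairseq's Dictionary, where recent versions expose this as a staticmethod; the corpus path and worker count are made up):

from fairseq.data import Dictionary

d = Dictionary()
Dictionary.add_file_to_dictionary(
    "corpus/train.txt",   # hypothetical corpus file
    d,
    tokenize=str.split,   # simple whitespace tokenizer
    num_workers=4,        # per-chunk Counters are merged via merge_result()
)
d.finalize()              # sort symbols by frequency and apply thresholds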
Example #5
File: binarizer.py Project: tma15/fairseq
    def _binarize_file_chunk(
        binarizer: Binarizer,
        filename: str,
        offset_start: int,
        offset_end: int,
        output_prefix: str,
        dataset_impl: str,
        vocab_size=None,
    ) -> tp.Tuple[tp.Any, BinarizeSummary]:  # (dataset builder, BinarizeSummary)
        """
        creates a dataset builder and append binarized items to it. This function does not
        finalize the builder, this is useful if you want to do other things with your bin file
        like appending/merging other files
        """
        bin_file = indexed_dataset.data_file_path(output_prefix)
        ds = indexed_dataset.make_builder(
            bin_file,
            impl=dataset_impl,
            vocab_size=vocab_size,
        )
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                ds.add_item(binarizer.binarize_line(line, summary))

        return ds, summary
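A hypothetical driver for the chunk worker above (the binarizer class, helper locations, and paths are assumptions based on the same fairseq generation; the merge/finalize step is elided):

from fairseq.binarizer import VocabularyDatasetBinarizer
from fairseq.data import Dictionary
from fairseq.file_chunker_utils import find_offsets

vocab = Dictionary.load("data-bin/dict.txt")   # hypothetical dict file
binarizer = VocabularyDatasetBinarizer(vocab)
offsets = find_offsets("corpus/train.txt", 4)  # line-aligned chunk boundaries
for worker_id, (start, end) in enumerate(zip(offsets, offsets[1:])):
    ds, summary = _binarize_file_chunk(
        binarizer, "corpus/train.txt", start, end,
        output_prefix=f"data-bin/train.worker{worker_id}",  # one builder each
        dataset_impl="mmap", vocab_size=len(vocab),
    )
    # ...the caller then merges the per-worker builders and writes the .idx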
Example #6
 def _add_file_to_dictionary_single_worker(
     filename, tokenize, eos_word, worker_id=0, num_workers=1
 ):
     counter = Counter()
     with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
         size = os.fstat(f.fileno()).st_size
         chunk_size = size // num_workers
         offset = worker_id * chunk_size
         end = offset + chunk_size
         f.seek(offset)
         if offset > 0:
             safe_readline(f)  # drop first incomplete line
         line = f.readline()
         while line:
             for word in tokenize(line):
                 counter.update([word])
             counter.update([eos_word])
             # f.tell() returns an opaque number that can be passed back to
             # f.seek() to return to the same position; it does not
             # necessarily represent a byte offset in the file. In practice,
             # however, f.tell() tracks the byte position most of the time,
             # so we can check it against the file size to prevent an early
             # exit.
             if f.tell() > end and f.tell() < size:
                 break
             line = f.readline()
     return counter
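The boundary handling above is worth seeing end to end. A self-contained sketch (no fairseq needed) that splits a temporary file into two chunks and checks that the merged counts match a single pass:

import os
import tempfile
from collections import Counter

def count_chunk(path, offset, end):
    counter = Counter()
    with open(path, "r", encoding="utf-8") as f:
        size = os.fstat(f.fileno()).st_size
        f.seek(offset)
        if offset > 0:
            f.readline()  # drop the partial line; the previous chunk owns it
        line = f.readline()
        while line:
            counter.update(line.split())
            if f.tell() > end and f.tell() < size:
                break
            line = f.readline()
    return counter

with tempfile.NamedTemporaryFile("w", delete=False) as tmp:
    tmp.write("a b c\nd e f\na b\nc d e f\n" * 100)
    path = tmp.name

size = os.path.getsize(path)
merged = count_chunk(path, 0, size // 2) + count_chunk(path, size // 2, size)
assert merged == count_chunk(path, 0, size)  # every line counted exactly once
os.remove(path)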
Example #7
    def _add_file_to_dictionary_single_worker(filename,
                                              tokenize,
                                              eos_word,
                                              worker_id=0,
                                              num_workers=1,
                                              L=None):
        # This method is modified to use only the first L characters of each line.

        counter = Counter()
        with open(PathManager.get_local_path(filename), "r",
                  encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_workers
            offset = worker_id * chunk_size
            end = offset + chunk_size
            f.seek(offset)
            if offset > 0:
                safe_readline(f)  # drop first incomplete line
            line = f.readline()
            while line:
                for word in tokenize(line[:L]):
                    counter.update([word])
                counter.update([eos_word])
                if f.tell() > end:
                    break
                line = f.readline()
        return counter
Example #8
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ) -> Dict[str, int]:
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r",
                  encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:
                # f.tell() does not always give the byte position in the
                # file; sometimes it jumps to a very large number. It is
                # unlikely that a normal read advances from end bytes to
                # end + 2**32 bytes (4 GB), which makes it unlikely that the
                # procedure breaks due to the non-deterministic behavior of
                # f.tell().
                if end > 0 and f.tell() > end and f.tell() < end + 2**32:
                    break
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }
Example #9
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ):

        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r",
                  encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:

                if end > 0 and f.tell() > end:
                    break
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )

                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }
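A hypothetical end-to-end driver for the binarize() variants above (the dict path, corpus path, and dataset impl are assumptions; assumes fairseq's Dictionary and indexed_dataset builders):

from fairseq.data import Dictionary, indexed_dataset

vocab = Dictionary.load("data-bin/dict.txt")  # hypothetical dict file
builder = indexed_dataset.make_builder(
    "data-bin/train.bin", impl="mmap", vocab_size=len(vocab)
)
stats = binarize("corpus/train.txt", vocab, consumer=builder.add_item)
builder.finalize("data-bin/train.idx")        # writes the .idx companion file
print(stats["nseq"], "seqs;", stats["nunk"], "unks out of", stats["ntok"], "tokens")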
Example #10
 def find_offsets(filename, num_chunks):
     with open(PathManager.get_local_path(filename), "r",
               encoding="utf-8") as f:
         size = os.fstat(f.fileno()).st_size
         chunk_size = size // num_chunks
         offsets = [0 for _ in range(num_chunks + 1)]
         for i in range(1, num_chunks):
             f.seek(chunk_size * i)
             safe_readline(f)
             offsets[i] = f.tell()
         return offsets
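A quick standalone check of what find_offsets() guarantees, assuming fairseq is importable (PathManager passes plain local paths through). Note the final slot stays 0, which the readers above treat as "read to EOF" via their end > 0 guard:

import os
import tempfile

with tempfile.NamedTemporaryFile("w", delete=False) as tmp:
    tmp.write("".join(f"line {i}\n" for i in range(1000)))
    path = tmp.name

offsets = find_offsets(path, 4)      # five slots for four chunks; last stays 0
with open(path, "rb") as f:
    for off in offsets[1:-1]:
        f.seek(off - 1)
        assert f.read(1) == b"\n"    # every interior offset starts a new line
os.remove(path)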
Example #11
def cached_path_from_pm(url_or_filename):
    """
    Tries to cache the specified URL using PathManager class.
    Returns the cached path if success otherwise failure.
    """
    try:
        from fairseq.file_io import PathManager
        local_path = PathManager.get_local_path(url_or_filename)
        return local_path
    except Exception:
        return None
Example #12
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    with open(PathManager.get_local_path(path), "rb") as f:
        state = torch.load(
            f, map_location=lambda s, l: default_restore_location(s, "cpu"))

    args = state["args"]
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)
    state = _upgrade_state_dict(state)
    return state
Example #13
 def binarize_tag(filename, consumer, offset=0, end=-1) -> Dict[str, int]:
     nseq = 0
     with open(PathManager.get_local_path(filename), "r") as f:
         f.seek(offset)
         line = safe_readline(f)
         while line:
             if end > 0 and f.tell() > end:
                 break
             ids = torch.IntTensor(
                 [int(num.strip()) + 1 for num in line.split(',')])
             nseq += 1
             consumer(ids)
             line = f.readline()
     return {"nseq": nseq}
Example #14
File: binarizer.py Project: sdadas/fairseq
    def binarize_alignments(filename,
                            alignment_parser,
                            consumer,
                            offset=0,
                            end=-1) -> Dict[str, int]:
        nseq = 0

        with Chunker(PathManager.get_local_path(filename), offset,
                     end) as line_iterator:
            for line in line_iterator:
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
        return {"nseq": nseq}
Example #15
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    with open(PathManager.get_local_path(path), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
        overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
Example #16
File: binarizer.py Project: sdadas/fairseq
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ) -> Dict[str, int]:
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with Chunker(PathManager.get_local_path(filename), offset,
                     end) as line_iterator:
            for line in line_iterator:
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }
Example #17
    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                print("RAV MSG CATCH", f)
                if (f == '/home/ubuntu/project/manifest/train/dict.ltr.txt'):
                    f = '/home/ras306/Classwork/Project/Wav2Vec/model/dict.ltr.txt'
                with open(PathManager.get_local_path(f), "r",
                          encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please "
                                "rebuild the dataset".format(f))
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        for line in lines[indices_start_line:]:
            try:
                line, field = line.rstrip().rsplit(" ", 1)
                if field == "#fairseq:overwrite":
                    overwrite = True
                    line, field = line.rsplit(" ", 1)
                else:
                    overwrite = False
                count = int(field)
                word = line
                if word in self and not overwrite:
                    raise RuntimeError(
                        "Duplicate word found when loading Dictionary: '{}'. "
                        "Duplicate words can overwrite earlier ones by adding the "
                        "#fairseq:overwrite flag at the end of the corresponding row "
                        "in the dictionary file. If using the Camembert model, please "
                        "download an updated copy of the model file.".format(
                            word))
                self.add_symbol(word, n=count, overwrite=overwrite)
            except ValueError:
                raise ValueError(
                    "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                )
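The expected format is one "<token> <cnt> [flags]" row per line. A self-contained sketch with made-up symbols (using fairseq's stock Dictionary rather than the patched variant above):

from fairseq.data import Dictionary

with open("dict.tmp.txt", "w", encoding="utf-8") as f:
    f.write("the 1500\n"
            "of 900\n"
            "of 10 #fairseq:overwrite\n")  # duplicates are legal only with the flag

d = Dictionary()
d.add_from_file("dict.tmp.txt")
print(d.index("the"), d.index("of"))       # both symbols resolve to indices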
Example #18
    def binarize_alignments(filename,
                            alignment_parser,
                            consumer,
                            offset=0,
                            end=-1):
        nseq = 0

        with open(PathManager.get_local_path(filename), "r") as f:
            f.seek(offset)
            line = safe_readline(f)
            while line:
                if end > 0 and f.tell() > end:
                    break
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
                line = f.readline()
        return {"nseq": nseq}
Example #19
 def binarize_da(filename,
                 consumer,
                 da_mapping,
                 offset=0,
                 end=-1) -> Dict[str, int]:
     nseq = 0
     num_mapping = len(da_mapping)
     with open(PathManager.get_local_path(filename), "r") as f:
         f.seek(offset)
         line = safe_readline(f)
         while line:
             if end > 0 and f.tell() > end:
                 break
             ids = torch.IntTensor(
                 [da_mapping.get(line.strip(), num_mapping + 1)])
             nseq += 1
             consumer(ids)
             line = f.readline()
     return {"nseq": nseq}
Example #20
File: wer.py Project: valentinp72/espresso
 def add_ordered_utt_list(self, *args):
     if len(args) == 1 and isinstance(
             args[0], list):  # already a list of utterance ids
         self.ordered_utt_list = args[0]
         return
     self.ordered_utt_list = []
     for text_file in args:
         with open(PathManager.get_local_path(text_file),
                   "r",
                   encoding="utf-8") as f:
             one_utt_list = [line.strip().split()[0] for line in f]
             self.ordered_utt_list.extend(one_utt_list)
     if len(self.char_results):
         assert set(self.ordered_utt_list) == set(self.char_results.keys())
     if len(self.results):
         assert set(self.ordered_utt_list) == set(self.results.keys())
     if len(self.aligned_results):
         assert set(self.ordered_utt_list) == set(
             self.aligned_results.keys())
Example #21
    def load(cls, f, enable_bos=False, f_non_lang_syms=None):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        Optionally add bos symbol to the dictionary

        Identifies the space symbol if it exists, by obtaining its index
        (space_index=-1 if no space symbol)

        Loads non_lang_syms from another text file, if it exists, with one
        symbol per line
        """
        d = cls(enable_bos=enable_bos)
        d.add_from_file(f)

        d.space_index = d.indices.get(d.space_word, -1)

        if f_non_lang_syms is not None:
            assert isinstance(f_non_lang_syms, str)
            try:
                with open(PathManager.get_local_path(f_non_lang_syms),
                          "r",
                          encoding="utf-8") as fd:
                    non_lang_syms = [x.rstrip() for x in fd.readlines()]
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please "
                                "rebuild the dataset".format(fd))

            for sym in non_lang_syms:
                assert (d.index(sym) !=
                        d.unk()), "{} in {} is not in the dictionary".format(
                            sym, f_non_lang_syms)
            d.non_lang_syms = non_lang_syms

        return d
Example #22
def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    print("RAV MSG CATCH: ", path)
    if (path == '/home/ubuntu/project/model/wav2vec_small.pt'):
        path = '/home/ras306/Classwork/Project/Wav2Vec/model/wav2vec_small.pt'
    with open(PathManager.get_local_path(path), "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state[
            "args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state[
            "cfg"] is not None and arg_overrides is not None:
        overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
Example #23
 def parse_wer_output_filter(self, wer_output_filter):
     if wer_output_filter:
         with open(PathManager.get_local_path(wer_output_filter),
                   'r',
                   encoding='utf-8') as f:
             for line in f:
                 line = line.strip()
                 if line.startswith('#!') or line == '':
                     continue
                 elif line.startswith('s/'):
                     m = re.match(r's/(.+)/(.*)/g', line)
                     assert m is not None
                     self.word_filters.append([m.group(1), m.group(2)])
                 elif line.startswith('s:'):
                     m = re.match(r's:(.+):(.*):g', line)
                     assert m is not None
                     self.word_filters.append([m.group(1), m.group(2)])
                 else:
                     logger.warning(
                         'Unsupported pattern: "{}". Ignoring it'.format(
                             line))
Example #24
File: wer.py Project: valentinp72/espresso
 def parse_wer_output_filter(self, wer_output_filter):
     if wer_output_filter:
         with open(PathManager.get_local_path(wer_output_filter),
                   "r",
                   encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 if line.startswith("#!") or line == "":
                     continue
                 elif line.startswith("s/"):
                     m = re.match(r"s/(.+)/(.*)/g", line)
                     assert m is not None
                     self.word_filters.append([m.group(1), m.group(2)])
                 elif line.startswith("s:"):
                     m = re.match(r"s:(.+):(.*):g", line)
                     assert m is not None
                     self.word_filters.append([m.group(1), m.group(2)])
                 else:
                     logger.warning(
                         "Unsupported pattern: '{}'. Ignoring it".format(
                             line))
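What the collected word_filters are for: each (pattern, replacement) pair is later applied to hypothesis and reference text. A standalone sketch with made-up Kaldi-style rules:

import re

filters = []
for line in ["s/%HESITATION/ /g", "s:<unk>::g", "#! comments are skipped"]:
    if line.startswith("s/"):
        m = re.match(r"s/(.+)/(.*)/g", line)
        filters.append((m.group(1), m.group(2)))
    elif line.startswith("s:"):
        m = re.match(r"s:(.+):(.*):g", line)
        filters.append((m.group(1), m.group(2)))

text = "UM %HESITATION THE <unk> CAT"
for pattern, repl in filters:
    text = re.sub(pattern, repl, text)
print(text)  # filters applied in order: "UM   THE  CAT"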
Example #25
File: binarizer.py Project: tma15/fairseq
    def _consume_file(
        filename: str,
        binarizer: Binarizer,
        consumer: tp.Callable[[torch.IntTensor], None],
        offset_start: int,
        offset_end: int,
    ) -> tp.Dict[str, int]:
        summary = BinarizeSummary()

        with Chunker(
            PathManager.get_local_path(filename), offset_start, offset_end
        ) as line_iterator:
            for line in line_iterator:
                consumer(binarizer.binarize_line(line, summary))

        return {
            "nseq": summary.num_seq,
            "nunk": summary.num_replaced,
            "ntok": summary.num_tok,
            "replaced": summary.replaced,
        }
Example #26
def upgrade_state_dict_with_infoxlm_weights(
        state_dict: Dict[str, Any],
        pretrained_infoxlm_checkpoint: str,
        num_layers: int,
        shared_cross_attn: bool = False) -> Dict[str, Any]:
    """
    Load XLM weights into a Transformer encoder or decoder model.

    Args:
        state_dict: state dict for either TransformerEncoder or
            TransformerDecoder
        pretrained_infoxlm_checkpoint: checkpoint to load XLM weights from

    Raises:
        AssertionError: If architecture (num layers, attention heads, etc.)
            does not match between the current Transformer encoder or
            decoder and the pretrained_infoxlm_checkpoint
    """
    if not os.path.exists(pretrained_infoxlm_checkpoint):
        raise IOError(
            "Model file not found: {}".format(pretrained_infoxlm_checkpoint))

    # state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_infoxlm_checkpoint)
    with open(PathManager.get_local_path(pretrained_infoxlm_checkpoint),
              "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))
    infoxlm_state_dict = state["model"]
    # print(state_dict.keys())

    for key in infoxlm_state_dict.keys():
        if 'layers' in key and int(key.split('.')[3]) > num_layers - 1:
            continue
        if not key.startswith('decoder.'):
            continue
        if 'lm_head' not in key:
            if 'in_proj_weight' in key:
                q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_weight',
                                                   'q_proj.weight')] = q
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_weight',
                                                   'k_proj.weight')] = k
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_weight',
                                                   'v_proj.weight')] = v
                if shared_cross_attn:
                    state_dict[key.replace(
                        'decoder.sentence_encoder.',
                        '').replace('in_proj_weight', 'q_proj.weight').replace(
                            'self_attn', 'encoder_attn')] = q
                    state_dict[key.replace(
                        'decoder.sentence_encoder.',
                        '').replace('in_proj_weight', 'k_proj.weight').replace(
                            'self_attn', 'encoder_attn')] = k
                    state_dict[key.replace(
                        'decoder.sentence_encoder.',
                        '').replace('in_proj_weight', 'v_proj.weight').replace(
                            'self_attn', 'encoder_attn')] = v
            elif 'in_proj_bias' in key:
                q, k, v = infoxlm_state_dict[key].chunk(3, dim=0)
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_bias',
                                                   'q_proj.bias')] = q
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_bias',
                                                   'k_proj.bias')] = k
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '').replace('in_proj_bias',
                                                   'v_proj.bias')] = v
                if shared_cross_attn:
                    state_dict[key.replace('decoder.sentence_encoder.',
                                           '').replace('in_proj_bias',
                                                       'q_proj.bias').replace(
                                                           'self_attn',
                                                           'encoder_attn')] = q
                    state_dict[key.replace('decoder.sentence_encoder.',
                                           '').replace('in_proj_bias',
                                                       'k_proj.bias').replace(
                                                           'self_attn',
                                                           'encoder_attn')] = k
                    state_dict[key.replace('decoder.sentence_encoder.',
                                           '').replace('in_proj_bias',
                                                       'v_proj.bias').replace(
                                                           'self_attn',
                                                           'encoder_attn')] = v
            elif 'emb_layer_norm' in key:
                state_dict[key.replace(
                    'decoder.sentence_encoder.emb_layer_norm',
                    'layernorm_embedding')] = infoxlm_state_dict[key]
            elif 'embed_positions' in key:
                state_dict[key.replace(
                    'decoder.sentence_encoder.',
                    '')] = infoxlm_state_dict[key][:state_dict[key.replace(
                        'decoder.sentence_encoder.', '')].size(0)]
            elif 'embed_tokens' in key:
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '')][:infoxlm_state_dict[key].
                                            size(0)] = infoxlm_state_dict[key]
            else:
                state_dict[key.replace('decoder.sentence_encoder.',
                                       '')] = infoxlm_state_dict[key]

    return state_dict
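The core move above is splitting a fused q/k/v projection into thirds; in isolation (toy sizes, no fairseq required):

import torch

embed_dim = 4
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # fused q/k/v weights
q, k, v = in_proj_weight.chunk(3, dim=0)                # three (4, 4) blocks
assert q.shape == k.shape == v.shape == (embed_dim, embed_dim)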
Example #27
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
        avoid_tokenize=False,
    ) -> Dict[str, int]:
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        def replaced_consumer_from_pretrained(word, idx):
            if idx == dict.convert_tokens_to_ids(
                    dict.unk_token) and word != dict.unk_token:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r",
                  encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:
                # f.tell() does not always give the byte position in the
                # file; sometimes it jumps to a very large number. It is
                # unlikely that a normal read advances from end bytes to
                # end + 2**32 bytes (4 GB), which makes it unlikely that the
                # procedure breaks due to the non-deterministic behavior of
                # f.tell().
                if end > 0 and f.tell() > end and f.tell() < end + 2**32:
                    break
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                elif isinstance(dict, BertTokenizer) and not isinstance(
                        dict, ElectraTokenizer):
                    line = line.strip()
                    line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                    if avoid_tokenize is False:
                        tokenizedline = dict.tokenize(line)
                    else:
                        tokenizedline = line.strip().split()
                    # max-len:1000000000000
                    # print('----------bert_max-len:' + str(dict.max_len) + '----------')
                    # if len(tokenizedline) > dict.max_len:
                    #     tokenizedline = tokenizedline[:dict.max_len - 1]
                    #     tokenizedline.append('[SEP]')
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif isinstance(dict, BartTokenizer):
                    line = line.strip()
                    if avoid_tokenize is False:
                        # extra space at the end will cause weird outputs.
                        line = '{} {}{}'.format('<s>', line, '</s>')
                        tokenizedline = dict.tokenize(line)
                    else:
                        line = '{} {} {}'.format('<s>', line, '</s>')
                        tokenizedline = line.strip().split()
                    # tokenizedline = dict.tokenize(line)
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    assert len(tokenizedline) == len(words)
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif isinstance(dict, ElectraTokenizer):
                    line = line.strip()
                    line = '{} {} {}'.format('[CLS]', line, '[SEP]')
                    if avoid_tokenize is False:
                        tokenizedline = dict.tokenize(line)
                    else:
                        tokenizedline = line.strip().split()
                    # max-len:1000000000000
                    # print('----------bert_max-len:' + str(dict.max_len) + '----------')
                    # if len(tokenizedline) > dict.max_len:
                    #     tokenizedline = tokenizedline[:dict.max_len - 1]
                    #     tokenizedline.append('[SEP]')
                    words = dict.convert_tokens_to_ids(tokenizedline)
                    #
                    # import pdb; pdb.set_trace()
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                        replaced_consumer_from_pretrained(
                            tokenizedline[i], word)
                elif dict is None:
                    line = line.strip()
                    words = line.split()
                    words = [int(item) for item in words]
                    nwords = len(words)
                    ids = torch.IntTensor(nwords)
                    for i, word in enumerate(words):
                        ids[i] = word
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }
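For the BertTokenizer branch, the wrap-then-convert step looks like this in isolation (assumes HuggingFace transformers; the model name is an assumption):

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")  # hypothetical model
line = "[CLS] the cat sat on the mat [SEP]"               # manual special tokens
tokens = tok.tokenize(line)        # special tokens survive tokenization intact
ids = tok.convert_tokens_to_ids(tokens)
print(tokens[0], ids[0])           # [CLS] 101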