Example #1
import logging

import torch

from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask

logger = logging.getLogger(__name__)


# DummyMaskedLMConfig and DummyDataset are defined alongside this task in
# fairseq's benchmark code.
class DummyMaskedLMTask(FairseqTask):
    def __init__(self, cfg: DummyMaskedLMConfig):
        super().__init__(cfg)

        self.dictionary = Dictionary()
        for i in range(cfg.dict_size):
            self.dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: {} types".format(len(self.dictionary)))
        # add mask token
        self.mask_idx = self.dictionary.add_symbol("<mask>")
        self.dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

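        # Build one fixed synthetic example: token ids start at pad_idx + 1,
        # every 7th position (starting at index 2) is replaced by the mask id,
        # and the target is pad everywhere except at the masked positions.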
        mask_idx = 0
        pad_idx = 1
        seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1
        mask = torch.arange(2, cfg.tokens_per_sample, 7)  # ~15%
        src = seq.clone()
        src[mask] = mask_idx
        tgt = torch.full_like(seq, pad_idx)
        tgt[mask] = seq[mask]

        self.dummy_src = src
        self.dummy_tgt = tgt

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if self.cfg.batch_size is not None:
            bsz = self.cfg.batch_size
        else:
            bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample)
        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens":
                    torch.stack([self.dummy_src for _ in range(bsz)]),
                    "src_lengths":
                    torch.full(
                        (bsz, ), self.cfg.tokens_per_sample, dtype=torch.long),
                },
                "target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
                "nsentences": bsz,
                "ntokens": bsz * self.cfg.tokens_per_sample,
            },
            num_items=self.cfg.dataset_size,
            item_size=self.cfg.tokens_per_sample,
        )

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary
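The call to pad_to_multiple_(8) above rounds the vocabulary size up by appending "madeupwordNNNN" filler entries until len(dictionary) is divisible by 8, which tends to map better onto GPU kernels. A minimal sketch of just that step, assuming a default fairseq Dictionary() (which starts with its four special symbols):

from fairseq.data import Dictionary

d = Dictionary()                       # starts with <s>, <pad>, </s>, <unk>
for i in range(10):
    d.add_symbol("word{}".format(i))   # 14 entries
mask_idx = d.add_symbol("<mask>")      # 15 entries
d.pad_to_multiple_(8)                  # appends one "madeupword" filler
assert len(d) % 8 == 0                 # 16 entries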
Example #2
import logging
from typing import List

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def ensure_symbols_are_present(dictionary: Dictionary, symbols: List[str],
                               ok_to_increase_dict_size: bool) -> None:
    """
    Ensure that the given symbols are present in the dictionary.

    Makes changes to the dictionary in-place.
    """
    original_size = len(dictionary)
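    # remove_madeupwords_from_dictionary is defined elsewhere in this project;
    # it presumably strips the "madeupwordNNNN" fillers added by a previous
    # pad_to_multiple_ call, so the real symbols are appended before re-padding.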
    _ = remove_madeupwords_from_dictionary(dictionary)
    for symbol in symbols:
        dictionary.add_symbol(symbol)
    dictionary.pad_to_multiple_(8)
    if not ok_to_increase_dict_size:
        # Don't crash here; just warn that the dictionary size was not supposed to change.
        if len(dictionary) != original_size:
            logger.warning(
                "The dictionary size changed. The model loading will probably fail."
            )
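A hypothetical call site for the helper above; the dictionary path and the extra symbols below are placeholders, not values from the original code:

src_dict = Dictionary.load("dict.src.txt")   # placeholder path
ensure_symbols_are_present(
    src_dict,
    symbols=["<mask>", "<sep>"],             # placeholder symbols
    ok_to_increase_dict_size=False,          # only warn if the size changes
)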
Example #3
from typing import Dict, List, Optional, Sequence

from fairseq.data import Dictionary

# get_lang_tok, LangTokSpec and LangTokStyle are defined alongside this helper
# in fairseq's multilingual data utilities.


def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value, ),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
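    """Add a language token for every (spec, language) pair, plus "<mask>" when
    needed, and pad the dictionary size to a multiple of 8."""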
    for spec in langtoks_specs:
        for language in language_list:
            dictionary.add_symbol(
                get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec)
            )

    if lang_tok_style == LangTokStyle.mbart.value or (
        extra_data is not None and LangTokSpec.mono_dae.value in extra_data
    ):
        dictionary.add_symbol("<mask>")
    dictionary.pad_to_multiple_(8)
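An illustrative call to the function above, assuming LangTokStyle.multilingual.value is the "multilingual" token style (which yields tokens of the form "__en__"); the language list is made up:

d = Dictionary()
augment_dictionary(
    dictionary=d,
    language_list=["en", "de", "fr"],                 # made-up language list
    lang_tok_style=LangTokStyle.multilingual.value,   # adds tokens like "__en__"
)
assert len(d) % 8 == 0                                # re-padded at the end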