def create_lang_dictionary(cls, langs):
     """Build a dictionary holding one symbol per entry of ``langs``.

     Args:
         cls: the owning class (not used in the body).
         langs: iterable of language identifiers; each one is added to the
             dictionary in iteration order.

     Returns:
         The newly created ``Dictionary``.
     """
     placeholder = "<unk>"
     # Hack: point every special-symbol slot at the same "<unk>" string so the
     # language dictionary carries no special symbols other than unk, since
     # they are not needed here.
     special_slots = dict.fromkeys(("pad", "eos", "unk", "bos"), placeholder)
     language_dictionary = Dictionary(**special_slots)
     for language in langs:
         language_dictionary.add_symbol(language)
     return language_dictionary
예제 #2
0
 def setup_task(cls, args, **kwargs):
     """Set up the task around a synthetic dictionary.

     Args:
         args: options object; ``args.dict_size`` is the number of
             placeholder symbols ("word0", "word1", ...) added to the
             dictionary.
         **kwargs: accepted but not used in the body.

     Returns:
         ``cls(args, dictionary)`` built with the freshly filled dictionary.
     """
     vocab = Dictionary()
     placeholder_symbols = (
         "word{}".format(index) for index in range(args.dict_size)
     )
     for symbol in placeholder_symbols:
         vocab.add_symbol(symbol)
     vocab_size = len(vocab)
     logger.info("dictionary: {} types".format(vocab_size))
     return cls(args, vocab)
예제 #3
0
    def setup_task(cls, args, **kwargs):
        """Set up the task around a synthetic dictionary.

        Besides building the dictionary, this writes
        ``args.max_source_positions`` and ``args.max_target_positions``,
        derived from ``args.src_len`` / ``args.tgt_len``, the dictionary's
        pad index, and a constant 2.

        Args:
            args: options object; reads ``dict_size``, ``src_len`` and
                ``tgt_len`` and is mutated as described above.
            **kwargs: accepted but not used in the body.

        Returns:
            ``cls(args, dictionary)`` built with the freshly filled dictionary.
        """
        vocab = Dictionary()
        placeholder_symbols = (
            "word{}".format(index) for index in range(args.dict_size)
        )
        for symbol in placeholder_symbols:
            vocab.add_symbol(symbol)
        logger.info("dictionary: {} types".format(len(vocab)))

        # Both position limits are offset by the same pad index.
        pad_index = vocab.pad()
        args.max_source_positions = args.src_len + pad_index + 2
        args.max_target_positions = args.tgt_len + pad_index + 2

        return cls(args, vocab)
예제 #4
0
class DummyLMTask(FairseqTask):
    """Language-modeling task that serves one fixed, synthetic batch.

    The dictionary is filled with placeholder symbols ("word0", "word1", ...)
    and the source/target tensors are slices of a single arithmetic token
    sequence built in ``__init__``.
    """

    def __init__(self, cfg: DummyLMConfig):
        """Build the placeholder dictionary and the dummy token sequences.

        Args:
            cfg: task configuration; ``dict_size`` and ``tokens_per_sample``
                are read here.
        """
        super().__init__(cfg)

        # Fill the dictionary with placeholder symbols.
        vocab = self.dictionary = Dictionary()
        for index in range(cfg.dict_size):
            vocab.add_symbol("word{}".format(index))
        vocab.pad_to_multiple_(8)  # often faster if divisible by 8
        logger.info("dictionary: {} types".format(len(vocab)))

        # tokens_per_sample + 1 consecutive ids, starting just after the pad
        # index; the extra element lets source and target be shifted views.
        positions = torch.arange(cfg.tokens_per_sample + 1)
        first_token = vocab.pad() + 1
        seq = positions + first_token

        # Target is the source shifted by one token.
        self.dummy_src, self.dummy_tgt = seq[:-1], seq[1:]

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Stores a ``DummyDataset`` wrapping one pre-built batch under
        ``self.datasets[split]``. ``epoch``, ``combine`` and ``kwargs`` are
        not used in the body.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        cfg = self.cfg
        sample_len = cfg.tokens_per_sample

        # Use the explicit batch size when given; otherwise derive it from
        # the token budget, with a floor of one sentence.
        bsz = cfg.batch_size
        if bsz is None:
            bsz = max(1, cfg.max_tokens // sample_len)

        src_tokens = torch.stack([self.dummy_src] * bsz)
        src_lengths = torch.full((bsz, ), sample_len, dtype=torch.long)
        target = torch.stack([self.dummy_tgt] * bsz)

        batch = {
            "id": 1,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
            },
            "target": target,
            "nsentences": bsz,
            "ntokens": bsz * sample_len,
        }
        self.datasets[split] = DummyDataset(
            batch,
            num_items=cfg.dataset_size,
            item_size=sample_len,
        )

    @property
    def source_dictionary(self):
        """Return the task's single dictionary (shared with the target)."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Return the task's single dictionary (shared with the source)."""
        return self.dictionary
0
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value, ),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    """Add language tokens (and possibly ``"<mask>"``) to ``dictionary`` in place.

    For every spec in ``langtoks_specs`` and, within it, every language in
    ``language_list``, the token produced by ``get_lang_tok`` is added.
    Afterwards ``"<mask>"`` is added when ``lang_tok_style`` equals
    ``LangTokStyle.mbart.value``, or when ``extra_data`` is given and contains
    ``LangTokSpec.mono_dae.value``.

    Args:
        dictionary: dictionary to extend; mutated via ``add_symbol``.
        language_list: languages to create tokens for.
        lang_tok_style: style passed through to ``get_lang_tok``.
        langtoks_specs: specs passed through to ``get_lang_tok``.
        extra_data: optional mapping checked for the mono-DAE key.
    """
    # Spec-major order: all languages for the first spec, then the next spec.
    for tok_spec in langtoks_specs:
        for lang_name in language_list:
            lang_token = get_lang_tok(
                lang=lang_name,
                lang_tok_style=lang_tok_style,
                spec=tok_spec,
            )
            dictionary.add_symbol(lang_token)

    if lang_tok_style == LangTokStyle.mbart.value:
        wants_mask = True
    elif extra_data is None:
        wants_mask = False
    else:
        wants_mask = LangTokSpec.mono_dae.value in extra_data

    if wants_mask:
        dictionary.add_symbol("<mask>")