Esempio n. 1
0
 def build_dictionary(cls,
                      filenames,
                      workers=1,
                      threshold=-1,
                      nwords=-1,
                      padding_factor=8):
     """Build and finalize a MaskedLMDictionary from text files.

     Each file is fed to ``Dictionary.add_file_to_dictionary`` with
     ``tokenizer.tokenize_line`` as the tokenizer. All counts accumulate in
     one dictionary, which is finalized before being returned.

     Args:
         filenames: iterable of file paths to read.
         workers: worker count forwarded to
             ``Dictionary.add_file_to_dictionary``.
         threshold: forwarded to ``finalize(threshold=...)``.
         nwords: forwarded to ``finalize(nwords=...)``.
         padding_factor: forwarded to ``finalize(padding_factor=...)``.

     Returns:
         The finalized MaskedLMDictionary.
     """
     vocab = MaskedLMDictionary()
     for path in filenames:
         Dictionary.add_file_to_dictionary(
             path, vocab, tokenizer.tokenize_line, workers)
     finalize_kwargs = dict(threshold=threshold,
                            nwords=nwords,
                            padding_factor=padding_factor)
     vocab.finalize(**finalize_kwargs)
     return vocab
Esempio n. 2
0
    def load_dictionary(cls, filename):
        """Read a masked-LM dictionary from disk.

        Args:
            filename (str): path of the dictionary file.

        Returns:
            Whatever ``MaskedLMDictionary.load(filename)`` returns.
        """
        loaded = MaskedLMDictionary.load(filename)
        return loaded
Esempio n. 3
0
    def setup_task(cls, args, **kwargs):
        """Set up the task from parsed arguments.

        Loads ``dict.<monolingual_lang>.txt`` from ``args.data``, prints its
        size, and constructs the task.

        Args:
            args: parsed arguments; ``args.data`` and
                ``args.monolingual_lang`` are read.
            **kwargs: accepted but unused here.
        """
        dict_path = os.path.join(args.data,
                                 'dict.%s.txt' % args.monolingual_lang)
        vocab = MaskedLMDictionary.load(dict_path)
        print('| dictionary: {} types'.format(len(vocab)))
        return cls(args, vocab)
Esempio n. 4
0
    def setup_task(cls, args, **kwargs):
        """Set up a translation task: normalize flags and load both dictionaries.

        Args:
            args (argparse.Namespace): parsed command-line arguments. This is
                mutated in place: left-pad flags are converted with
                ``options.eval_bool``, ``dataset_impl`` may be overridden by
                deprecated flags, and ``source_lang``/``target_lang`` may be
                inferred.
            **kwargs: accepted but unused here.

        Returns:
            ``cls(args, src_dict, tgt_dict)``.

        Raises:
            Exception: if the language pair is neither given nor inferable.
        """
        for flag in ('left_pad_source', 'left_pad_target'):
            setattr(args, flag, options.eval_bool(getattr(args, flag)))

        # Deprecated flags map onto --dataset-impl; only the first one that is
        # set takes effect.
        deprecated_flags = (
            ('raw_text',
             '--raw-text is deprecated, please use --dataset-impl=raw',
             'raw'),
            ('lazy_load',
             '--lazy-load is deprecated, please use --dataset-impl=lazy',
             'lazy'),
        )
        for attr, warning, impl in deprecated_flags:
            if getattr(args, attr, False):
                utils.deprecation_warning(warning)
                args.dataset_impl = impl
                break

        paths = args.data.split(':')
        assert len(paths) > 0
        data_dir = paths[0]

        # Infer the language pair from the first data directory when either
        # side was not supplied.
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = \
                data_utils.infer_language_pair(data_dir)
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        def _load(lang):
            return MaskedLMDictionary.load(
                os.path.join(data_dir, 'dict.{}.txt'.format(lang)))

        src_dict = _load(args.source_lang)
        tgt_dict = _load(args.target_lang)
        # Both vocabularies must agree on their special-symbol indices.
        for special in ('pad', 'eos', 'unk', 'mask'):
            assert getattr(src_dict, special)() == getattr(tgt_dict, special)()

        for lang, vocab in ((args.source_lang, src_dict),
                            (args.target_lang, tgt_dict)):
            print('| [{}] dictionary: {} types'.format(lang, len(vocab)))

        return cls(args, src_dict, tgt_dict)
Esempio n. 5
0
 def _read_fairseq_vocab(self,
                         vocab_file: str,
                         max_vocab: int = -1,
                         min_count: int = -1) -> Tuple[List, List, Dict]:
     """Load a fairseq vocabulary file and finalize it.

     Args:
         vocab_file: path passed to ``MaskedLMDictionary.load``.
         max_vocab: forwarded to ``finalize`` as ``nwords``.
         min_count: forwarded to ``finalize`` as ``threshold``.

     Returns:
         A tuple ``(symbols, counts, replacements)``. ``symbols`` and
         ``counts`` are the finalized dictionary's ``symbols`` and ``count``
         attributes. ``replacements`` maps fairseq's special-token strings
         to the module-level ``PAD``/``EOS``/``UNK``/``MASK`` values.
     """
     vocab = MaskedLMDictionary.load(vocab_file)
     # NOTE(review): padding_factor=1 presumably disables size padding of
     # the vocabulary — confirm against MaskedLMDictionary.finalize.
     vocab.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
     symbols = vocab.symbols
     token_counts = vocab.count
     special_tokens = dict([
         ("<pad>", PAD),
         ("</s>", EOS),
         ("<unk>", UNK),
         ("<mask>", MASK),
     ])
     return symbols, token_counts, special_tokens
Esempio n. 6
0
 def load_dictionary(cls, filename):
     """Return the MaskedLMDictionary loaded from ``filename``."""
     dictionary = MaskedLMDictionary.load(filename)
     return dictionary
Esempio n. 7
0
    def prepare(cls, args, **kwargs):
        """Normalize multilingual arguments and load one dictionary per language.

        ``args`` is mutated in place. The comma-separated language and step
        lists are parsed, and every referenced language is checked against
        the declared ones. The method also sets ``pred_probs``,
        ``mono_lang_pairs``, ``para_lang_pairs``, ``valid_lang_pairs``,
        ``langs_id``, ``ids_lang``, ``eval_lang_pair``, ``n_lang`` and
        ``eval_para``.

        Args:
            args: parsed command-line arguments.
            **kwargs: accepted but unused here.

        Returns:
            tuple: ``(dicts, training)``.

            - ``dicts`` is an OrderedDict mapping each language in the sorted
              ``args.langs`` to its MaskedLMDictionary.
            - ``training`` is False when both ``args.source_lang`` and
              ``args.target_lang`` are set, otherwise True.
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # First three comma-separated values of --word-mask-keep-rand.
        # NOTE(review): the name suggests mask/keep/random probabilities —
        # confirm.
        probs = [float(x) for x in args.word_mask_keep_rand.split(',')]
        args.pred_probs = torch.FloatTensor([probs[0], probs[1], probs[2]])

        def _non_empty_items(csv):
            # Split a comma-separated string and drop empty entries.
            return [item for item in csv.split(',') if item]

        args.langs = sorted(args.langs.split(','))
        args.source_langs = sorted(args.source_langs.split(','))
        args.target_langs = sorted(args.target_langs.split(','))
        for lang in args.source_langs + args.target_langs:
            assert lang in args.langs

        args.mass_steps = _non_empty_items(args.mass_steps)
        args.mt_steps = _non_empty_items(args.mt_steps)
        args.memt_steps = _non_empty_items(args.memt_steps)
        para_steps = args.mt_steps + args.memt_steps

        # Each MASS step contributes an "x-x" pair built from its first
        # language.
        args.mono_lang_pairs = [
            '{0}-{0}'.format(step.split('-')[0]) for step in args.mass_steps
        ]

        # Canonicalize each parallel pair by sorting its two languages. The
        # result is sorted so its order does not depend on set iteration
        # order, which varies between processes because str hashing is
        # randomized.
        args.para_lang_pairs = sorted(
            {'-'.join(sorted(step.split('-'))) for step in para_steps})

        args.valid_lang_pairs = _non_empty_items(args.valid_lang_pairs)

        for lang_pair in (args.mono_lang_pairs + para_steps
                          + args.valid_lang_pairs):
            src, tgt = lang_pair.split('-')
            assert src in args.source_langs and tgt in args.target_langs

        if args.source_lang is not None:
            assert args.source_lang in args.source_langs
        if args.target_lang is not None:
            assert args.target_lang in args.target_langs

        args.langs_id = {lang: i for i, lang in enumerate(args.langs)}
        args.ids_lang = {i: lang for i, lang in enumerate(args.langs)}

        # If source_lang and target_lang are both given, switch to
        # translation (non-training) mode for that direction.
        if args.source_lang is not None and args.target_lang is not None:
            args.eval_lang_pair = '{}-{}'.format(args.source_lang,
                                                 args.target_lang)
            training = False
        else:
            if args.para_lang_pairs:
                # Use the first parallel step as listed, so every process
                # deterministically picks the same evaluation pair.
                args.eval_lang_pair = para_steps[0]
            else:
                args.eval_lang_pair = args.mono_lang_pairs[0]
            training = True
        args.n_lang = len(args.langs_id)
        args.eval_para = len(args.para_lang_pairs) > 0

        dicts = OrderedDict()
        for lang in args.langs:
            dicts[lang] = MaskedLMDictionary.load(
                os.path.join(args.data, 'dict.{}.txt'.format(lang)))
            # Every language must share the first language's special-symbol
            # indices.
            reference = dicts[args.langs[0]]
            for special in ('pad', 'eos', 'unk', 'mask'):
                assert (getattr(dicts[lang], special)()
                        == getattr(reference, special)())
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Esempio n. 8
0
def _lang_token_index(dic: MaskedLMDictionary, lang: str):
    """Return the index of the language token for ``lang`` in ``dic``.

    The token string is produced by ``_lang_token(lang)`` and looked up with
    ``dic.index``.

    Raises:
        ValueError: if the lookup yields ``dic.unk_index``, which presumably
            means the token is absent from the dictionary.
    """
    idx = dic.index(_lang_token(lang))
    # Explicit check rather than ``assert``, so validation is not stripped
    # when Python runs with -O.
    if idx == dic.unk_index:
        raise ValueError(f'cannot find language token for lang {lang}')
    return idx