def load_dictionary(cls, filename):
        """Load a masked-LM dictionary from disk.

        Args:
            filename (str): path to the dictionary file

        Returns:
            the loaded ``MaskedLMDictionary``
        """
        dictionary = MaskedLMDictionary.load(filename)
        return dictionary
# Example #2
    def prepare(cls, args, **kwargs):
        """Normalize command-line args and load one dictionary per language.

        Mutates ``args`` in place (padding flags, dataset-impl deprecation
        shims, splitting ``--lang-pairs``), then loads a MaskedLMDictionary
        for every language involved and checks that the special-symbol ids
        (pad/eos/unk) agree across all of them.

        Args:
            args: parsed argparse namespace; modified in place.

        Returns:
            (dicts, training): an OrderedDict mapping language -> dictionary,
            and a bool that is False when an explicit --source-lang /
            --target-lang selects inference mode.

        Raises:
            ValueError: if ``--lang-pairs`` was not supplied.
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        # Back-compat shims for the removed --raw-text / --lazy-load flags.
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        if args.lang_pairs is None:
            raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.')
        args.lang_pairs = args.lang_pairs.split(',')
        # Unique languages appearing on either side of any pair, sorted for
        # a deterministic dictionary-loading order.
        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = args.data.split(':')
            assert len(paths) > 0
            dicts[lang] = MaskedLMDictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            # Fix: the original guard was `> 0`, which is always true here and
            # compared the first dictionary against itself; `> 1` skips that
            # redundant self-comparison.
            if len(dicts) > 1:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                # Reserve a __lang__ token for every language in the mix.
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
# Example #3
    def setup_task(cls, args, **kwargs):
        """Set up the task by loading the shared masked-LM dictionary
        from ``args.data``.
        """
        dict_path = os.path.join(args.data, 'dict.txt')
        dictionary = MaskedLMDictionary.load(dict_path)
        print('| dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
# Example #4
 def _read_fairseq_vocab(
     self, vocab_file: str, max_vocab: int = -1, min_count: int = -1
 ) -> Tuple[List, List, Dict]:
     """Read a fairseq-format vocabulary file.

     Returns the symbol list, the per-symbol counts, and a mapping from
     fairseq special tokens to this project's special-token constants.
     """
     fs_dict = MaskedLMDictionary.load(vocab_file)
     # Prune rare symbols / cap size; padding_factor=1 disables padding the
     # vocab size to a multiple.
     fs_dict.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1)
     special_token_map = {"<pad>": PAD, "</s>": EOS, "<unk>": UNK, "<mask>": MASK}
     return fs_dict.symbols, fs_dict.count, special_token_map
# Example #5
 def setup_task(cls, args, **kwargs):
     """Build the task after loading the masked-LM dictionary from ``args.data``."""
     vocab_path = os.path.join(args.data, "dict.txt")
     dictionary = MaskedLMDictionary.load(vocab_path)
     logger.info("dictionary: {} types".format(len(dictionary)))
     return cls(args, dictionary)
# Example #6
 def load_dictionary(cls, filename):
     """Return the ``MaskedLMDictionary`` stored at *filename*."""
     loaded = MaskedLMDictionary.load(filename)
     return loaded
# Example #7
    def prepare(cls, args, **kwargs):
        """Parse, validate and normalize MASS-style training arguments.

        Mutates ``args`` in place: splits comma-separated language and step
        lists, derives monolingual/parallel language pairs, builds
        language<->id maps, and chooses the evaluation language pair. Then
        loads one MaskedLMDictionary per language and checks that the
        special-symbol ids (pad/eos/unk/mask) agree across languages.

        Args:
            args: parsed argparse namespace; modified in place.

        Returns:
            (dicts, training): an OrderedDict mapping language -> dictionary,
            and a bool that is False when explicit --source-lang and
            --target-lang select translation (inference) mode.
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        # word_mask_keep_rand carries three comma-separated probabilities for
        # the masked prediction task (mask / keep / random replacement).
        probs = [float(x) for x in args.word_mask_keep_rand.split(',')]
        setattr(args, 'pred_probs', torch.FloatTensor([probs[0], probs[1], probs[2]]))

        args.langs = sorted(args.langs.split(','))
        args.source_langs = sorted(args.source_langs.split(','))
        args.target_langs = sorted(args.target_langs.split(','))

        # NOTE(review): assert-based validation is stripped under `python -O`.
        for lang in args.source_langs:
            assert lang in args.langs
        for lang in args.target_langs:
            assert lang in args.langs

        # Drop empty entries produced by trailing or duplicated commas.
        args.mass_steps = [s for s in args.mass_steps.split(',') if len(s) > 0]
        args.mt_steps   = [s for s in args.mt_steps.split(',')   if len(s) > 0]
        args.memt_steps = [s for s in args.memt_steps.split(',') if len(s) > 0]

        # Monolingual MASS steps use same-language "xx-xx" pairs.
        mono_langs = [
            lang_pair.split('-')[0]
            for lang_pair in args.mass_steps
            if len(lang_pair) > 0
        ]
        mono_lang_pairs = ['{}-{}'.format(lang, lang) for lang in mono_langs]
        setattr(args, 'mono_lang_pairs', mono_lang_pairs)

        # Canonicalize parallel pairs so 'en-fr' and 'fr-en' collapse to one.
        # NOTE(review): set iteration order is unordered, so the resulting
        # list order may vary between runs.
        args.para_lang_pairs = list(set([
            '-'.join(sorted(lang_pair.split('-')))
            for lang_pair in set(args.mt_steps + args.memt_steps) if
            len(lang_pair) > 0
        ]))

        args.valid_lang_pairs = [s for s in args.valid_lang_pairs.split(',') if len(s) > 0]

        # Every configured pair must use known source/target languages
        # (checked in the same order the original code did: mono, then
        # mt+memt, then valid).
        for lang_pair in (args.mono_lang_pairs
                          + args.mt_steps + args.memt_steps
                          + args.valid_lang_pairs):
            src, tgt = lang_pair.split('-')
            assert src in args.source_langs and tgt in args.target_langs

        if args.source_lang is not None:
            assert args.source_lang in args.source_langs

        if args.target_lang is not None:
            assert args.target_lang in args.target_langs

        # Bidirectional language <-> id maps in sorted-language order.
        langs_id = {}
        ids_lang = {}
        for i, v in enumerate(args.langs):
            langs_id[v] = i
            ids_lang[i] = v
        setattr(args, 'langs_id', langs_id)
        setattr(args, 'ids_lang', ids_lang)

        # If provide source_lang and target_lang, we will switch to translation
        if args.source_lang is not None and args.target_lang is not None:
            setattr(args, 'eval_lang_pair', '{}-{}'.format(args.source_lang, args.target_lang))
            training = False
        else:
            if len(args.para_lang_pairs) > 0:
                # NOTE(review): this picks an arbitrary element of an
                # unordered set, so the chosen eval pair may differ between
                # runs; kept as-is to preserve behavior.
                required_para = [s for s in set(args.mt_steps + args.memt_steps)]
                setattr(args, 'eval_lang_pair', required_para[0])
            else:
                setattr(args, 'eval_lang_pair', args.mono_lang_pairs[0])
            training = True
        setattr(args, 'n_lang', len(langs_id))
        # Fix: was `True if len(...) > 0 else False` — redundant ternary.
        setattr(args, 'eval_para', len(args.para_lang_pairs) > 0)

        # Load one dictionary per language; all must share special-symbol ids.
        dicts = OrderedDict()
        for lang in args.langs:
            dicts[lang] = MaskedLMDictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
            # Fix: the original guard was `> 0`, which is always true here and
            # compared the first dictionary against itself; `> 1` skips that
            # redundant self-comparison.
            if len(dicts) > 1:
                assert dicts[lang].pad() == dicts[args.langs[0]].pad()
                assert dicts[lang].eos() == dicts[args.langs[0]].eos()
                assert dicts[lang].unk() == dicts[args.langs[0]].unk()
                assert dicts[lang].mask() == dicts[args.langs[0]].mask()
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training