Пример #1
0
 def load_dictionary_and_postproc(path):
     """Load the dictionary stored at ``path`` and post-process it.

     The freshly loaded dictionary is handed to ``augment_dictionary``
     together with ``language_list`` and the language-token settings read
     from ``args`` (both taken from the enclosing scope).

     Args:
         path: location passed verbatim to ``load_dictionary``.

     Returns:
         The object returned by ``load_dictionary``. The return value of
         ``augment_dictionary`` is ignored, so it presumably updates the
         dictionary in place -- confirm against its definition.
     """
     loaded = load_dictionary(path)
     # Settings all come from the surrounding scope, not from the caller.
     augment_kwargs = dict(
         dictionary=loaded,
         language_list=language_list,
         lang_tok_style=args.lang_tok_style,
         langtoks_specs=args.langtoks_specs,
         extra_data=args.extra_data,
     )
     augment_dictionary(**augment_kwargs)
     return loaded
Пример #2
0
    def prepare(cls, load_dictionary, args, **kargs):
        """Normalize ``args``, resolve the language list and load dictionaries.

        ``args`` is updated in place:

        * ``left_pad_source`` / ``left_pad_target`` are passed through
          ``utils.eval_bool``;
        * ``shuffle_instance`` is set to ``False`` when the attribute is absent;
        * ``langtoks`` becomes a dict and gets a ``"main"`` entry built from
          ``encoder_langtok`` / ``decoder_langtok`` when it has none;
        * ``lang_pairs`` is split on commas when given as a string.

        Args:
            load_dictionary: callable that takes a dictionary file path and
                returns the loaded dictionary.
            args: option namespace read and updated as described above.
            **kargs: forwarded unchanged to ``cls.load_langs``.

        Returns:
            A ``(language_list, dicts, training)`` tuple where
            ``language_list`` is the result of ``cls.load_langs``, ``dicts``
            is an ``OrderedDict`` mapping each language code to its
            dictionary, and ``training`` is ``True`` only when neither
            ``args.source_lang`` nor ``args.target_lang`` is set.

        Raises:
            ValueError: if ``args.lang_pairs`` is ``None``, if a language pair
                is not of the form ``src-tgt``, or if a pair mentions a
                language missing from the language list.
            AssertionError: if ``args.data`` yields no path, or if the
                dictionaries disagree on their pad/eos/unk indices.
        """
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        if not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False
        if args.langtoks is None:
            args.langtoks = {}
        if "main" not in args.langtoks:
            src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None
            tgt_langtok_spec = "tgt" if args.decoder_langtok else None
            args.langtoks["main"] = (src_langtok_spec, tgt_langtok_spec)

        def check_langs(langs, pairs):
            """Raise ``ValueError`` listing every malformed or unknown pair."""
            messages = []
            for pair in pairs:
                if len(pair) != 2:
                    # Report the offending pair instead of failing with an
                    # opaque tuple-unpacking error.
                    pair_name = "-".join(str(part) for part in pair)
                    messages.append(
                        f"language pair {pair_name} is not of the form src-tgt"
                    )
                    continue
                src, tgt = pair
                if src not in langs or tgt not in langs:
                    messages.append(
                        f"language pair {src}-{tgt} contains languages "
                        "that are not in the language dictionary"
                    )
            if messages:
                raise ValueError(" ".join(messages) + f"; langs: {langs}")

        if args.lang_pairs is None:
            raise ValueError(
                "--lang-pairs is required. List all the language pairs in the training objective."
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

        # An explicit source or target language means a single fixed
        # direction was requested, which is treated as non-training mode.
        training = args.source_lang is None and args.target_lang is None

        language_list = cls.load_langs(args, **kargs)
        if training:
            pairs_to_check = [p.split("-") for p in args.lang_pairs]
        else:
            pairs_to_check = [(args.source_lang, args.target_lang)]
        check_langs(language_list, pairs_to_check)

        # load dictionaries
        if training:
            if args.extra_lang_pairs:
                # De-duplicate the pairs listed across all extra-data entries.
                extra_lang_pairs = list(
                    {p for v in args.extra_lang_pairs.values() for p in v.split(",")}
                )
            else:
                extra_lang_pairs = []
            langs_to_load_dicts = sorted(
                {x for p in args.lang_pairs + extra_lang_pairs for x in p.split("-")}
            )
        else:
            langs_to_load_dicts = sorted([args.source_lang, args.target_lang])

        dicts = OrderedDict()
        paths = utils.split_paths(args.data)
        assert len(paths) > 0, "args.data must contain at least one path"
        for lang in langs_to_load_dicts:
            # Dictionaries are always read from the first data path.
            dicts[lang] = load_dictionary(
                os.path.join(paths[0], "dict.{}.txt".format(lang))
            )
            augment_dictionary(
                dictionary=dicts[lang],
                language_list=language_list,
                lang_tok_style=args.lang_tok_style,
                langtoks_specs=args.langtoks_specs,
                extra_data=args.extra_data,
            )
            # Every dictionary must agree with the first one on the indices
            # of the special symbols. The first language is inserted on the
            # first iteration, so the reference entry always exists here.
            reference = dicts[langs_to_load_dicts[0]]
            assert dicts[lang].pad() == reference.pad(), "pad index mismatch"
            assert dicts[lang].eos() == reference.eos(), "eos index mismatch"
            assert dicts[lang].unk() == reference.unk(), "unk index mismatch"
            logger.info("[%s] dictionary: %s types", lang, len(dicts[lang]))
        return language_list, dicts, training