Example #1
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
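A minimal sketch (not part of the example above) of how a setup_task classmethod like this could be invoked directly for debugging; the task class name MyTranslationTask and the data directory are placeholders, and only the attributes the method actually reads are set. In fairseq proper, tasks are normally constructed through the task registry rather than called like this.

from argparse import Namespace

args = Namespace(
    data='data-bin/example',     # placeholder; must contain dict.{lang}.txt files
    source_lang='de',            # set explicitly to skip automatic pair inference
    target_lang='en',
    left_pad_source='True',      # strings, parsed by options.eval_bool above
    left_pad_target='False',
)
task = MyTranslationTask.setup_task(args)   # hypothetical class; returns cls(args, src_dict, tgt_dict)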
Example #2
File: utils.py Project: fyabc/fairseq
def dummy_dictionary(vocab_size, prefix='token_'):
    d = Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize(padding_factor=1)  # don't add extra padding symbols
    return d
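An illustrative sketch of using the helper above in a test; the assertions are mine, but index(), nspecial and len() are standard fairseq.data.Dictionary members that also appear in other examples on this page.

d = dummy_dictionary(vocab_size=10)
assert len(d) == 10 + d.nspecial    # 10 dummy tokens plus the reserved special symbols
idx = d.index('token_0')            # look up one of the generated tokens
assert d[idx] == 'token_0'          # index() and __getitem__ round-trip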
Example #3
    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split == 'train':
            path = 'data/train_960_splited'
        else:
            path = 'data/dev_clean_splited'
        manifest = os.path.join(path, "cmvn_by_len_2.scp")
        self.datasets[split] = KaldiFileDataset(
            manifest,
            sample_rate=self.args.sample_rate,
            max_sample_size=self.args.max_sample_size,
            min_sample_size=self.args.min_sample_size
            if not self.args.no_min_cropping else self.args.max_sample_size,
            min_length=self.args.min_sample_size,
            pad=self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
        )

        if self.args.labels:
            dict_path = os.path.join(self.args.data,
                                     f"dict.{self.args.labels}.txt")
            self._target_dictionary = Dictionary.load(dict_path)
            label_path = os.path.join(self.args.data,
                                      f"{split}.{self.args.labels}")
            labels = []
            with open(label_path, "r") as f:
                for line in f:
                    labels.append(line)

            process_label = LabelEncoder(self.target_dictionary)

            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                add_to_input=not self.is_ctc,
            )
Example #4
    def test_finalize(self):
        txt = [
            'A B C D',
            'B C D',
            'C D',
            'D',
        ]
        ref_ids1 = list(map(torch.IntTensor, [
            [4, 5, 6, 7, 2],
            [5, 6, 7, 2],
            [6, 7, 2],
            [7, 2],
        ]))
        ref_ids2 = list(map(torch.IntTensor, [
            [7, 6, 5, 4, 2],
            [6, 5, 4, 2],
            [5, 4, 2],
            [4, 2],
        ]))

        # build dictionary
        d = Dictionary()
        for line in txt:
            Tokenizer.tokenize(line, d, add_if_not_exist=True)

        def get_ids(dictionary):
            ids = []
            for line in txt:
                ids.append(Tokenizer.tokenize(line, dictionary, add_if_not_exist=False))
            return ids

        def assertMatch(ids, ref_ids):
            for toks, ref_toks in zip(ids, ref_ids):
                self.assertEqual(toks.size(), ref_toks.size())
                self.assertEqual(0, (toks != ref_toks).sum().item())

        ids = get_ids(d)
        assertMatch(ids, ref_ids1)

        # check finalized dictionary
        d.finalize()
        finalized_ids = get_ids(d)
        assertMatch(finalized_ids, ref_ids2)

        # write to disk and reload
        with tempfile.NamedTemporaryFile(mode='w') as tmp_dict:
            d.save(tmp_dict.name)
            d = Dictionary.load(tmp_dict.name)
            reload_ids = get_ids(d)
            assertMatch(reload_ids, ref_ids2)
            assertMatch(finalized_ids, reload_ids)
Example #5
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        if args.lang_pairs is None:
            raise ValueError(
                '--lang-pairs is required. List all the language pairs in the training objective.'
            )
        args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(
            list({
                x
                for lang_pair in args.lang_pairs for x in lang_pair.split('-')
            }))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = args.data.split(':')
            assert len(paths) > 0
            dicts[lang] = Dictionary.load(
                os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Example #6
    def __init__(self, tokenizer_path):
        super().__init__()
        self.dict = Dictionary.load(os.path.join(tokenizer_path, 'dict.txt'))
        # <sep> and <pad> already exist in the dictionary
        self.index_special_tokens = {
            tok: self.dict.add_symbol(tok)
            for tok in special_tokens
        }

        args = Namespace(bpe='sentencepiece',
                         sample_break_mode='complete',
                         sentencepiece_vocab=os.path.join(
                             tokenizer_path, 'sentencepiece.bpe.model'))
        self.bpe = encoders.build_bpe(args)

        # this is useful for determining the device
        self.register_buffer('_float_tensor',
                             torch.tensor([0], dtype=torch.float))
        self.info = 'fairseq'
Example #7
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        dictionary = None
        output_dictionary = None
        if args.data:
            paths = args.data.split(':')
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
            print('| dictionary: {} types'.format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size)

        # upgrade old checkpoints
        if hasattr(args, 'exclude_self_target'):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, 'self_target', False):
            targets.append('self')
        if getattr(args, 'future_target', False):
            targets.append('future')
        if getattr(args, 'past_target', False):
            targets.append('past')
        if len(targets) == 0:
            # standard language modeling
            targets = ['future']

        return cls(args, dictionary, output_dictionary, targets=targets)
Example #8
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        if getattr(args, "raw_text", False):
            utils.deprecation_warning(
                "--raw-text is deprecated, please use --dataset-impl=raw")
            args.dataset_impl = "raw"
        elif getattr(args, "lazy_load", False):
            utils.deprecation_warning(
                "--lazy-load is deprecated, please use --dataset-impl=lazy")
            args.dataset_impl = "lazy"

        dictionary = None
        output_dictionary = None
        if args.data:
            paths = args.data.split(":")
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
            print("| dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size)

        # upgrade old checkpoints
        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, "self_target", False):
            targets.append("self")
        if getattr(args, "future_target", False):
            targets.append("future")
        if getattr(args, "past_target", False):
            targets.append("past")
        if len(targets) == 0:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)
Example #9
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries)."""
        if args.target_lang is None:
            dict_basename = "dict.txt"
        else:
            dict_basename = "dict.{}.txt".format(args.target_lang)
        dict_path = os.path.join(args.data.split(os.pathsep)[0], dict_basename)
        if not os.path.isfile(dict_path):
            raise FileNotFoundError("Dict not found: {}".format(dict_path))
        tgt_dict = Dictionary.load(dict_path)

        if args.criterion == "ctc_loss":
            tgt_dict.add_symbol("<ctc_blank>")
        elif args.criterion == "asg_loss":
            for i in range(1, args.max_replabel + 1):
                tgt_dict.add_symbol(replabel_symbol(i))

        print("| dictionary: {} types".format(len(tgt_dict)))
        return cls(args, tgt_dict)
Example #10
    def load_pretrained_model(path,
                              src_dict_path,
                              tgt_dict_path,
                              arg_overrides=None):
        model = utils.load_checkpoint_to_cpu(path)
        args = model['args']
        state_dict = model['model']
        args = utils.override_model_args(args, arg_overrides)
        src_dict = BertBasedDictionary(args.bert_name)
        tgt_dict = Dictionary.load(tgt_dict_path)
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = BertTranslationTask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state_dict)
        model.load_state_dict(state_dict, strict=True)
        return model
Example #11
def main():
    parser = get_parser()
    args = parser.parse_args()

    dictionary = Dictionary.load(args.dict) if args.dict is not None else None
    dataset = data_utils.load_indexed_dataset(
        args.input,
        dictionary,
        dataset_impl=args.dataset_impl,
        default="lazy",
    )

    for tensor_line in dataset:
        if dictionary is None:
            line = " ".join([str(int(x)) for x in tensor_line])
        else:
            line = dictionary.string(tensor_line)

        print(line)
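The loop above converts tensors back to text with dictionary.string(); for the reverse direction, here is a minimal sketch (placeholder dictionary path) using Dictionary.encode_line, which also appears in Example #33 below.

from fairseq.data import Dictionary

dictionary = Dictionary.load('dict.txt')    # placeholder path
ids = dictionary.encode_line('hello world',
                             append_eos=True,
                             add_if_not_exist=False)
print(ids)                                  # IntTensor of symbol indices (plus eos)
print(dictionary.string(ids))               # back to 'hello world' if both words are in the dict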
Example #12
    def __init__(self,
                 tgt_dict,
                 src_dict,
                 cl_ratio,
                 bpe_symbol='@@ ',
                 args=None):
        self.tgt_dict = tgt_dict
        self.src_dict = src_dict
        self.bpe_symbol = bpe_symbol
        self.cl_ratio = cl_ratio

        if args.cl_file == "all":
            model = torch.load('simile-mrt/cl_sim/model.de.lc.100_4_50000.pt',
                               map_location='cpu')
        elif args.cl_file == "wmt":
            model = torch.load(
                'simile-mrt/cl_sim/model.wmt.all.lc.100.0.0_25.pt',
                map_location='cpu')

        state_dict = model['state_dict']
        vocab_words = model['vocab']
        sim_args = model['args']

        #turn off gpu
        sim_args.gpu = -1

        if args.cl_file == "all":
            self.model = WordAveraging(
                sim_args,
                vocab_words,
                sp_file="simile-mrt/cl_sim/all.de.lc.sp.50k.model")
        elif args.cl_file == "wmt":
            self.model = WordAveraging(
                sim_args,
                vocab_words,
                sp_file="simile-mrt/cl_sim/wmt.all.lc.sp.50k.model")
        self.model.load_state_dict(state_dict, strict=True)
        # use a fresh Dictionary for scoring, so that we can add new elements
        self.scoring_dict = Dictionary()
        self.detok = MosesDetokenizer('en')

        self.lower_case = sim_args.lower_case
Example #13
    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        workers = args.workers
        threshold = args.thresholdsrc if src else args.thresholdtgt
        nwords = args.nwordssrc if src else args.nwordstgt
        padding_factor = args.padding_factor
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(filename, d,
                                              tokenizer.tokenize_line, workers,
                                              args.L)
        d.finalize(threshold=threshold,
                   nwords=nwords,
                   padding_factor=padding_factor)
        return d
Example #14
    def assert_word_shuffle_matches_expected(
        self,
        x,
        x_len,
        max_shuffle_distance: int,
        vocab: Dictionary,
        expected_shuffle_maps: List[Dict[int, int]],
        expect_eos_at_end: bool,
    ):
        """
        This verifies that with a given x, x_len, max_shuffle_distance, and
        vocab, we get the expected shuffle result.

        Args:
            x: Tensor of shape (T x B) = (sequence_length, batch_size)
            x_len: Tensor of length B = batch_size
            max_shuffle_distance: arg to pass to noising
            expected_shuffle_maps: List[mapping] where mapping is a
                Dict[old_index, new_index], mapping x's elements from their
                old positions in x to their new positions in x.
            expect_eos_at_end: if True, check the output to make sure there is
                an EOS at the end.
        """
        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(vocab)
            x_noised, l_noised = word_shuffle.noising(
                x, x_len, max_shuffle_distance=max_shuffle_distance)

        # For every example, we have a different expected shuffle map. We check
        # that each example is shuffled as expected according to each
        # corresponding shuffle map.
        for i in range(len(expected_shuffle_maps)):
            shuffle_map = expected_shuffle_maps[i]
            for k, v in shuffle_map.items():
                self.assertEqual(x[k][i], x_noised[v][i])

        # Shuffling should not affect the length of each example
        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
            self.assertEqual(pre_shuffle_length, post_shuffle_length)
        if expect_eos_at_end:
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
Example #15
    def __init__(
        self,
        model_path: str,
        dict_path: str,
        device: str,
        lang: str = "en",
        vad_model: VoiceActivityDetection = None,
    ) -> None:
        self.SAMPLE_RATE = 16000
        self.MINIMUM_INPUT_LENGTH = 1024

        self.target_dict = Dictionary.load(dict_path)                   # target_dict: dict loaded from /home/kris/.pororo/misc/ko.ltr.txt; length: 108 (104 vocabs + 4 more added (bos="<s>", pad="<pad>", eos="</s>", unk="<unk>"))

        self.lang = lang
        self.graphemes = BrainWav2Vec2Recognizer.graphemes[lang]        # None if en or zh
        self.device = device

        self.collate_fn = collate_fn                                    # merges a list of samples to form a mini-batch of Tensor(s) (returns tensor 'inputs' and int tensor 'input_lengths')
        self.model = self._load_model(model_path, device, self.target_dict)     # Wav2VecCTC model; model = BrainWav2VecCtc (made via BrainWav2VecCtc.build_model (w/ pretrained weights from model_path: (/home/kris/.pororo/misc/wav2vec.ko.pt)))
        self.generator = W2lViterbiDecoder(self.target_dict)
        self.vad_model = vad_model
Example #16
    def setup_dictionary(cls, args, **kwargs):
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
            if args.add_bos_token:
                languages, _ = cls._get_langs(args)
                logger.info("----------------")
                for lang in languages:
                    dictionary.add_symbol(lang_token(lang))
                    logger.info(f"add language token: {lang_token(lang)}")
                logger.info("----------------")

            logger.info("dictionary: {} types".format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size)
        return (dictionary, output_dictionary)
Example #17
    def __init__(
        self,
        model_path: str,
        dict_path: str,
        device: str,
        lang: str = "en",
        vad_model: VoiceActivityDetection = None,
    ) -> None:
        self.SAMPLE_RATE = 16000
        self.MINIMUM_INPUT_LENGTH = 1024

        self.target_dict = Dictionary.load(dict_path)

        self.lang = lang
        self.graphemes = BrainWav2Vec2Recognizer.graphemes[lang]
        self.device = device

        self.collate_fn = collate_fn
        self.model = self._load_model(model_path, device, self.target_dict)
        self.generator = W2lViterbiDecoder(self.target_dict)
        self.vad_model = vad_model
Example #18
    def __init__(self, tgt_dict, bpe_symbol='@@ ', args=None):
        self.tgt_dict = tgt_dict
        self.bpe_symbol = bpe_symbol
        #self.scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
        model = torch.load('simile-mrt/sim/sim.pt', map_location='cpu')

        state_dict = model['state_dict']
        vocab_words = model['vocab_words']
        sim_args = model['args']

        #turn off gpu
        sim_args.gpu = -1

        self.model = WordAveraging(sim_args,
                                   vocab_words,
                                   sp_file="simile-mrt/sim/sim.sp.30k.model")
        self.model.load_state_dict(state_dict, strict=True)
        # use a fresh Dictionary for scoring, so that we can add new elements
        self.scoring_dict = Dictionary()
        self.detok = MosesDetokenizer('en')
        self.tok = TreebankWordTokenizer()
Example #19
    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        self.datasets[split] = FileHandwritingDataset(
            self.args.data,
            split=split,
            max_sample_size=self.args.max_sample_size,
            min_sample_size=self.args.max_sample_size,
            pad_to_multiples_of=self.args.pad_to_multiples_of,
            min_length=self.args.min_sample_size,
            pad=self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
        )

        if self.args.labels:
            assert False  ## TODO(JCh): we must load labels from scribblelens.
            dict_path = os.path.join(self.args.data,
                                     f"dict.{self.args.labels}.txt")
            self._target_dictionary = Dictionary.load(dict_path)
            label_path = os.path.join(self.args.data,
                                      f"{split}.{self.args.labels}")
            labels = []
            with open(label_path, "r") as f:
                for line in f:
                    labels.append(line)

            process_label = LabelEncoder(self.target_dictionary)

            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                add_to_input=not self.is_ctc,
            )
Example #20
    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        manifest = os.path.join(self.args.data, "{}.tsv".format(split))
        self.datasets[split] = FileAudioDataset(
            manifest,
            sample_rate=self.args.sample_rate,
            max_sample_size=self.args.max_sample_size,
            min_sample_size=self.args.max_sample_size,
            min_length=self.args.min_sample_size,
            pad=self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
        )

        if self.args.labels:
            dict_path = os.path.join(self.args.data,
                                     f"dict.{self.args.labels}.txt")
            self._target_dictionary = Dictionary.load(dict_path)
            label_path = os.path.join(self.args.data,
                                      f"{split}.{self.args.labels}")
            labels = []
            with open(label_path, "r") as f:
                for line in f:
                    labels.append(line)

            process_label = LabelEncoder(self.target_dictionary)

            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                add_to_input=not self.is_ctc,
            )
Example #21
    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        manifest = os.path.join(self.args.data, "{}.tsv".format(split))

        self.datasets[split] = KDAudioDataset(
            manifest,
            sample_rate=self.args.sample_rate,
            max_sample_size=self.args.max_sample_size,
            min_sample_size=self.args.max_sample_size,
            min_length=self.args.min_sample_size,
            pad=True,  #self.args.labels is not None or self.args.enable_padding,
            normalize=self.args.normalize,
            feat_extension=self.args.feat_extension,
        )

        dict_path = os.path.join(self.args.data,
                                 f"dict.{self.args.labels}.txt")
        self._target_dictionary = Dictionary.load(dict_path)
Example #22
    def test_add_file_to_dict(self):
        counts = {}
        num_lines = 100
        per_line = 10
        with tempfile.TemporaryDirectory("test_sampling") as data_dir:
            filename = os.path.join(data_dir, "dummy.txt")
            with open(filename, "w", encoding="utf-8") as data:
                for c in string.ascii_letters:
                    line = f"{c} " * per_line
                    for _ in range(num_lines):
                        data.write(f"{line}\n")
                    counts[c] = per_line * num_lines
                    per_line += 5

            dict = Dictionary()
            Dictionary.add_file_to_dictionary(filename, dict,
                                              tokenizer.tokenize_line, 10)
            dict.finalize(threshold=0, nwords=-1, padding_factor=8)

            for c in string.ascii_letters:
                count = dict.get_count(dict.index(c))
                self.assertEqual(
                    counts[c], count,
                    f"{c} count is {count} but should be {counts[c]}")
Example #23
    def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs):
        """Setup the task."""
        paths = cfg.data.split(":")
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))

        data_path = paths[0]
        if cfg.langs is None:
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = cfg.langs.split(",")

        if cfg.add_lang_token:
            for lang in languages:
                dictionary.add_symbol("[{}]".format(lang))

        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(cfg, "shuffle_instance"):
            cfg.shuffle_instance = False
        return cls(cfg, dictionary)
Example #24
    def prepare(cls, args, **kargs):
        cls.update_args(args)
        sorted_langs = sorted(
            list({
                x
                for lang_pair in args.lang_pairs for x in lang_pair.split("-")
            }))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        bert_dict_langs: Set = set(args.use_bert_dict.split(","))
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0

            if lang in bert_dict_langs:
                logger.info("Use DirctionaryForBert for {}".format(lang))
                dicts[lang] = DictionaryForBert.load(
                    os.path.join(paths[0], "dict.{}.txt".format(lang)))
            else:
                logger.info("Use default Dirctionary for {}".format(lang))
                dicts[lang] = Dictionary.load(
                    os.path.join(paths[0], "dict.{}.txt".format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info("[{}] dictionary: {} types".format(
                lang, len(dicts[lang])))
        return dicts, training
Example #25
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d
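A short usage sketch (placeholder file names) assuming the method above lives on fairseq's FairseqTask or a subclass of it; save() is the counterpart of the Dictionary.load() calls used throughout these examples.

d = FairseqTask.build_dictionary(['train.de', 'valid.de'],
                                 workers=4,
                                 threshold=2,       # drop words that appear only once
                                 padding_factor=8)  # pad the vocab size to a multiple of 8
d.save('dict.de.txt')                               # later reloaded with Dictionary.load('dict.de.txt')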
Example #26
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError(
                '--lang-pairs is required. List all the language pairs in the training objective.'
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(
            list({
                x
                for lang_pair in args.lang_pairs for x in lang_pair.split('-')
            }))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = args.data.split(os.pathsep)
            assert len(paths) > 0
            dicts[lang] = Dictionary.load(
                os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info('[{}] dictionary: {} types'.format(
                lang, len(dicts[lang])))
        return dicts, training
Example #27
    def setup_task(cls, args, **kwargs):
        """Setup the task.
        """
        paths = args.data.split(':')
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))

        data_path = paths[0]
        if args.langs is None:
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = args.langs.split(',')

        if args.add_lang_token:
            for lang in languages:
                dictionary.add_symbol('[{}]'.format(lang))

        logger.info("dictionary: {} types".format(len(dictionary)))
        if not hasattr(args, 'shuffle_instance'):
            args.shuffle_instance = False
        return cls(args, dictionary)
Example #28
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if not hasattr(args, 'audio_input'):
            args.audio_input = False

        args.lang_pairs = args.lang_pairs.split(',')
        if args.source_lang is not None or args.target_lang is not None:
            #if args.lang_pairs is not None:
            #    raise ValueError(
            #        '--source-lang/--target-lang implies generation, which is '
            #        'incompatible with --lang-pairs'
            #    )
            training = False
            #args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
        else:
            training = True
            #args.lang_pairs = args.lang_pairs.split(',')
            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')

        langs = list(
            {x
             for lang_pair in args.lang_pairs for x in lang_pair.split('-')})

        # load dictionaries
        dicts = OrderedDict()
        for lang in langs:
            dicts[lang] = Dictionary.load(
                os.path.join(args.data, 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[langs[0]].pad()
                assert dicts[lang].eos() == dicts[langs[0]].eos()
                assert dicts[lang].unk() == dicts[langs[0]].unk()
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))

        return cls(args, dicts, training)
Example #29
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load dictionaries
        tgt_dict = Dictionary.load(
            os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, None, tgt_dict)
Example #30
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary = None
        output_dictionary = None
        if args.data:
            dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
            print('| dictionary: {} types'.format(len(dictionary)))
            output_dictionary = ResourceManager.get_senses_dictionary(True)
            print('| output_dictionary: {} types'.format(
                len(output_dictionary)))
            criterion_weights = torch.ones(len(output_dictionary)).float()
            criterion_weights[:output_dictionary.nspecial] = 0.
            criterion_weights.requires_grad = False
        else:
            raise NotImplementedError

        return cls(args,
                   dictionary,
                   output_dictionary,
                   criterion_weights=criterion_weights)
Example #31
def main():
    args = parser.parse_args()
    sample = dict()
    net_input = dict()

    feature = get_feature(args.wav_path)
    target_dict = Dictionary.load(args.target_dict_path)

    model = load_model(args.w2v_path, target_dict)
    model[0].eval()

    generator = W2lViterbiDecoder(target_dict)
    net_input["source"] = feature.unsqueeze(0)

    padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)

    net_input["padding_mask"] = padding_mask
    sample["net_input"] = net_input

    with torch.no_grad():
        hypo = generator.generate(model, sample, prefix_tokens=None)

    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    print(post_process(hyp_pieces, 'letter'))
Example #32
def augment_dictionary(
    dictionary: Dictionary,
    language_list: List[str],
    lang_tok_style: str,
    langtoks_specs: Sequence[str] = (LangTokSpec.main.value, ),
    extra_data: Optional[Dict[str, str]] = None,
) -> None:
    for spec in langtoks_specs:
        for language in language_list:
            dictionary.add_symbol(
                get_lang_tok(lang=language,
                             lang_tok_style=lang_tok_style,
                             spec=spec))

    if lang_tok_style == LangTokStyle.mbart.value or (
            extra_data is not None
            and LangTokSpec.mono_dae.value in extra_data):
        dictionary.add_symbol("<mask>")
    dictionary.pad_to_multiple_(8)
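A toy sketch of what the final pad_to_multiple_(8) call does to the vocabulary size; the numbers assume the default fairseq.data.Dictionary with four reserved special symbols.

d = Dictionary()
d.add_symbol('[en_XX]')
d.add_symbol('[de_DE]')
print(len(d))           # 6: four special symbols plus two language tokens
d.pad_to_multiple_(8)
print(len(d))           # 8: filler symbols added up to the next multiple of 8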
Example #33
class BERTweetTokenizer():
    def __init__(self, pretrained_path="../pretrained/bertweet/"):

        self.bpe = fastBPE(
            SimpleNamespace(
                bpe_codes=os.path.join(pretrained_path, "bpe.codes")))
        self.vocab = Dictionary()
        self.vocab.add_from_file(os.path.join(pretrained_path, "dict.txt"))
        self.cls_token_id = 0
        self.pad_token_id = 1
        self.sep_token_id = 2
        self.pad_token = '<pad>'
        self.cls_token = '<s>'
        self.sep_token = '</s>'

    def bpe_encode(self, text):
        return self.bpe.encode(text)

    def encode(self, text, add_special_tokens=False):
        subwords = self.bpe.encode(text)
        input_ids = self.vocab.encode_line(
            subwords, append_eos=False,
            add_if_not_exist=False).long().tolist()
        return input_ids

    def tokenize(self, text):
        return self.bpe_encode(text).split()

    def convert_tokens_to_ids(self, tokens):
        input_ids = self.vocab.encode_line(
            ' '.join(tokens), append_eos=False,
            add_if_not_exist=False).long().tolist()
        return input_ids

    #from: https://www.kaggle.com/nandhuelan/bertweet-first-look
    def decode_id(self, id):
        return self.vocab.string(id, bpe_symbol='@@')

    def decode_id_nospace(self, id):
        return self.vocab.string(id, bpe_symbol='@@ ')
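A usage sketch for the tokenizer class above; the pretrained_path must point at a local copy of the BERTweet assets (bpe.codes and dict.txt), which is assumed here rather than taken from the example.

tokenizer = BERTweetTokenizer(pretrained_path='../pretrained/bertweet/')
text = 'SC has first two presumptive cases of coronavirus'
ids = tokenizer.encode(text)
tokens = tokenizer.tokenize(text)
assert ids == tokenizer.convert_tokens_to_ids(tokens)  # both paths go through vocab.encode_line
print(tokenizer.decode_id(ids))                        # back to (BPE-merged) text via Dictionary.string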