Example #1
0
    def _load_dataset_split(self, split, epoch, combine):
        """Load ``split`` and wrap it into token blocks prefixed with ``<s>``.

        The data directory is chosen round-robin from the paths configured in
        ``self.cfg.data`` using ``epoch``. The indexed dataset found under
        ``<dir>/<split>`` is passed through ``maybe_shorten_dataset``, chunked
        by ``TokenBlockDataset`` and finally wrapped in ``PrependTokenDataset``.

        Args:
            split (str): name of the split to load
            epoch (int): 1-based epoch number used to pick the data directory
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if ``load_indexed_dataset`` returns ``None``.
        """
        shard_dirs = utils.split_paths(self.cfg.data)
        assert len(shard_dirs) > 0
        shard_dir = shard_dirs[(epoch - 1) % len(shard_dirs)]
        split_path = os.path.join(shard_dir, split)

        raw = data_utils.load_indexed_dataset(
            split_path, self.source_dictionary, combine=combine
        )
        if raw is None:
            message = "Dataset not found: {} ({})".format(split, split_path)
            raise FileNotFoundError(message)

        shortened = maybe_shorten_dataset(
            raw,
            split,
            self.cfg.shorten_data_split_list,
            self.cfg.shorten_method,
            self.cfg.tokens_per_sample,
            self.cfg.seed,
        )

        # Each block is one token shorter than the sample budget so that the
        # <s> token prepended below still fits.
        block_size = self.cfg.tokens_per_sample - 1
        blocks = TokenBlockDataset(
            shortened,
            shortened.sizes,
            block_size,
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.cfg.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(blocks), split_path))

        # <s> plays the role of [CLS] in BERT
        bos_index = self.source_dictionary.bos()
        return PrependTokenDataset(blocks, bos_index)
Example #2
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        One ``MaskedLMDataset`` is built for every language in
        ``self.langs2id`` and the per-language datasets are combined into a
        ``MultiCorpusSampledDataset`` stored under ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): epoch number, forwarded to
                ``_load_single_lang_dataset`` and used to name the data
                directory in the log message
            combine (bool): accepted for interface compatibility; not used by
                this method
        """
        dataset_map = OrderedDict()

        for lang, segment_id in self.langs2id.items():
            # Datasets are expected to be in "split.lang" format (Eg: train.en)
            language_split = "{}.{}".format(split, lang)

            block_dataset, sizes = self._load_single_lang_dataset(
                split=language_split, epoch=epoch
            )

            dataset_map[lang] = MaskedLMDataset(
                dataset=block_dataset,
                sizes=sizes,
                vocab=self.dictionary,
                pad_idx=self.dictionary.pad(),
                mask_idx=self.dictionary.mask(),
                # eos doubles as both the classification and separator token
                classif_token_idx=self.dictionary.eos(),
                sep_token_idx=self.dictionary.eos(),
                shuffle=getattr(self.args, "shuffle", False),
                has_pairs=False,
                segment_id=segment_id,
                seed=self.seed,
            )

        self.datasets[split] = MultiCorpusSampledDataset(dataset_map)

        # Wrap the index round-robin: indexing with a bare ``epoch - 1`` raised
        # IndexError as soon as the epoch number exceeded the number of paths.
        # NOTE(review): assumes _load_single_lang_dataset picks its directory
        # with the same 1-based round-robin rule -- confirm.
        paths = utils.split_paths(self.args.data)
        logger.info(
            "{} {} {} examples".format(
                paths[(epoch - 1) % len(paths)],
                split,
                len(self.datasets[split]),
            )
        )
Example #3
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Builds the dataset with ``get_asr_dataset_from_json``, records the
        feature dimension of its source side in ``self.feat_dim`` and, for the
        "train" split, refreshes the <eos> and <unk> counts of
        ``self.tgt_dict``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``get_asr_dataset_from_json``
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        chosen_path = all_paths[(epoch - 1) % len(all_paths)]

        loaded = get_asr_dataset_from_json(
            chosen_path,
            split,
            self.tgt_dict,
            combine=combine,
            upsample_primary=self.args.upsample_primary,
            num_buckets=self.args.num_batch_buckets,
            seed=self.args.seed,
            specaugment_config=self.specaugment_config,
        )
        self.datasets[split] = loaded

        # Unwrap one level of container/wrapper to reach the object that
        # exposes ``feat_dim``.
        source_side = self.datasets[split].src
        if isinstance(source_side, ConcatDataset):
            feat_holder = source_side.datasets[0]
        elif isinstance(source_side, BaseWrapperDataset):
            feat_holder = source_side.dataset
        else:
            feat_holder = source_side
        self.feat_dim = feat_holder.feat_dim

        if split != "train":
            return

        # update the counts of <eos> and <unk> in tgt_dict with training data
        target_side = self.datasets[split].tgt
        num_targets = len(target_side)
        self.tgt_dict.count[self.tgt_dict.eos()] = num_targets
        total_unk = 0
        for index in range(num_targets):
            matches = target_side[index][0] == self.tgt_dict.unk()
            total_unk += matches.int().sum().item()
        self.tgt_dict.count[self.tgt_dict.unk()] = total_unk
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Normalises the padding flags, infers the language pair from the first
        data path when it is not given, forces the target length limit to
        equal the source length limit, and loads one dictionary per language.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Raises:
            Exception: if the language pair cannot be inferred or
                ``args.max_source_positions`` is ``None``.
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        root = paths[0]

        def pair_is_incomplete():
            return args.source_lang is None or args.target_lang is None

        # find language pair automatically
        if pair_is_incomplete():
            inferred = data_utils.infer_language_pair(root)
            args.source_lang, args.target_lang = inferred
        if pair_is_incomplete():
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        if args.max_source_positions is None:
            raise Exception('Please specify max source positions!')

        # set source and target to be same length
        args.max_target_positions = args.max_source_positions

        # load dictionaries
        src_file = os.path.join(root, 'dict.{}.txt'.format(args.source_lang))
        src_dict = cls.load_dictionary(src_file)
        tgt_file = os.path.join(root, 'dict.{}.txt'.format(args.target_lang))
        tgt_dict = cls.load_dictionary(tgt_file)

        # both dictionaries must agree on the special symbol indices
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        for lang, dictionary in ((args.source_lang, src_dict),
                                 (args.target_lang, tgt_dict)):
            logger.info('[{}] dictionary: {} types'.format(lang,
                                                           len(dictionary)))

        return cls(args, src_dict, tgt_dict)
Example #5
0
    def setUp(self, cfg):
        """Build the task, a single model, the generator and text encoders.

        Args:
            cfg: configuration object; an ``argparse.Namespace`` is converted
                with ``convert_namespace_to_omegaconf`` first.

        Raises:
            Exception: if more than one model is loaded from
                ``cfg.common_eval.path``.
        """
        if isinstance(cfg, Namespace):
            cfg = convert_namespace_to_omegaconf(cfg)

        self.task = tasks.setup_task(cfg.task)
        self.tgt_dict = self.task.target_dictionary

        # Load ensemble
        logger.info("loading model(s) from {}".format(cfg.common_eval.path))
        checkpoint_paths = utils.split_paths(cfg.common_eval.path)
        loaded_models, _ = checkpoint_utils.load_model_ensemble(
            checkpoint_paths,
            arg_overrides={},
            task=self.task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=False,
            num_shards=cfg.checkpoint.checkpoint_shard_count,
        )
        if len(loaded_models) > 1:
            raise Exception(
                "Currently loading multiple models is not supported")
        self.model = loaded_models[0]

        # Optimize model for generation
        if cfg.common.fp16:
            self.model.half()
        if self.use_cuda:
            self.model.cuda()
        self.model.prepare_for_inference_(cfg)

        self.generator = self.task.build_generator(
            [self.model], cfg.generation, extra_gen_cls_kwargs={}
        )

        # Handle tokenization and BPE
        self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
        self.bpe = self.task.build_bpe(cfg.bpe)
        self.remove_bpe = cfg.common_eval.post_process
Example #6
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The training subset cycles through all configured data paths by
        epoch; every other split always reads from the first path. The result
        of ``load_langpair_dataset`` is stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``load_langpair_dataset``
        """
        all_paths = utils.split_paths(self.cfg.data)
        assert len(all_paths) > 0
        # if not training data set, use the first shard for valid and test
        is_train_subset = split == self.cfg.train_subset
        candidates = all_paths if is_train_subset else all_paths[:1]
        data_path = candidates[(epoch - 1) % len(candidates)]

        # infer langcode
        src = self.cfg.source_lang
        tgt = self.cfg.target_lang

        pair_options = dict(
            combine=combine,
            dataset_impl=self.cfg.dataset_impl,
            upsample_primary=self.cfg.upsample_primary,
            left_pad_source=self.cfg.left_pad_source,
            left_pad_target=self.cfg.left_pad_target,
            max_source_positions=self.cfg.max_source_positions,
            max_target_positions=self.cfg.max_target_positions,
            load_alignments=self.cfg.load_alignments,
            truncate_source=self.cfg.truncate_source,
            num_buckets=self.cfg.num_batch_buckets,
            # every split except "test" is shuffled
            shuffle=(split != "test"),
            pad_to_multiple=self.cfg.required_seq_len_multiple,
            add_token_type_ids=self.cfg.dataset_add_token_type_ids,
            prepend_bos_to_src=self.cfg.prepend_bos_to_src,
        )
        self.datasets[split] = load_langpair_dataset(
            data_path, split, src, self.src_dict, tgt, self.tgt_dict,
            **pair_options
        )
    def load_dataset(self, split, epoch=0, **kwargs):
        """Load a dataset split.

        Builds one language-pair dataset per entry of ``self.lang_pairs``,
        passes each through ``self.alter_dataset_langtok`` and stores them as
        a ``RoundRobinZipDatasets`` in ``self.datasets[split]``.

        Args:
            split (str): name of the split to load
            epoch (int): 0-based epoch number used to pick the data path
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        data_path = all_paths[epoch % len(all_paths)]

        per_pair = OrderedDict()
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split('-')
            base_dataset = load_langpair_dataset(
                data_path, split,
                src, self.dicts[src],
                tgt, self.dicts[tgt],
                combine=True,
                dataset_impl=self.args.dataset_impl,
                upsample_primary=self.args.upsample_primary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
                max_source_positions=self.args.max_source_positions,
                max_target_positions=self.args.max_target_positions,
            )
            per_pair[lang_pair] = self.alter_dataset_langtok(
                base_dataset,
                src_eos=self.dicts[src].eos(),
                src_lang=src,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )

        # No evaluation key while training; otherwise the single
        # "<source>-<target>" pair named on the command line.
        if self.training:
            eval_key = None
        else:
            eval_key = "%s-%s" % (self.args.source_lang, self.args.target_lang)

        self.datasets[split] = RoundRobinZipDatasets(per_pair, eval_key=eval_key)
    def _load_single_lang_dataset(self, split, epoch):
        """Load every numbered part of ``split`` as token-block datasets.

        Parts are looked up as ``split``, ``split1``, ``split2``, ... inside
        the data path selected by ``epoch``; loading stops at the first
        missing part.

        Args:
            split (str): base name of the split
            epoch (int): epoch number used (modulo the number of paths) to
                pick the data path

        Returns:
            tuple: ``(dataset, sizes)`` -- the single block dataset, or a
            ``ConcatDataset`` of all parts, together with its sizes array.

        Raises:
            FileNotFoundError: if not even the first part can be loaded.
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        data_path = all_paths[epoch % len(all_paths)]

        parts = []
        part_no = 0
        while True:
            suffix = str(part_no) if part_no > 0 else ''
            split_k = split + suffix
            part_path = os.path.join(data_path, split_k)

            indexed = data_utils.load_indexed_dataset(part_path, self.dictionary, self.args.dataset_impl)
            if indexed is None:
                if part_no == 0:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
                break

            # Since we append each block with the classification_token,
            # we need to effectively create blocks of length
            # tokens_per_sample-1
            block_dataset = TokenBlockDataset(
                indexed,
                indexed.sizes,
                self.args.tokens_per_sample - 1,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
            )
            parts.append(block_dataset)

            logger.info('{} {} {} examples'.format(data_path, split_k, len(parts[-1])))
            part_no += 1

        if len(parts) == 1:
            only_part = parts[0]
            return only_part, only_part.sizes

        merged = ConcatDataset(parts)
        merged_sizes = np.concatenate([part.sizes for part in parts])
        return merged, merged_sizes
Example #9
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split as a ``CodeCompletionDataset``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if ``load_indexed_dataset`` returns ``None``.
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        chosen_dir = all_paths[(epoch - 1) % len(all_paths)]
        split_path = os.path.join(chosen_dir, split)

        indexed = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        if indexed is None:
            message = 'Dataset not found: {} ({})'.format(split, split_path)
            raise FileNotFoundError(message)

        completion_options = dict(
            split_fn=self.split_fn,
            # every split except 'test' is shuffled
            shuffle=(split != 'test'),
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            append_eos_to_source=True,
            append_eos_to_target=True,
        )
        self.datasets[split] = CodeCompletionDataset(
            indexed, indexed.sizes, self.dictionary, **completion_options
        )

        num_samples = len(self.datasets[split])
        logger.info(
            "Split: {0}, Loaded {1} samples of CodeCompletionDataset".format(
                split, num_samples))
Example #10
0
    def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
        """Store task settings and resolve the list of data directories.

        ``common_dict`` is passed to the parent constructor as both the source
        and the target dictionary. When ``args.data`` names a single directory
        that contains entries matching ``shard*``, those entries replace it as
        the data list.
        """
        super().__init__(args, common_dict, common_dict)
        self.common_dict = common_dict
        self.mono_langs = mono_langs
        self.valid_lang_pairs = valid_lang_pairs

        self.SHOW_SAMPLES_INTERVAL = 1000
        # Counter starts at the interval value ("start by showing samples").
        self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
        self.SHOW_SAMPLES_NUMBER = 5
        self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
        self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)

        self.args = args
        self.data = utils.split_paths(self.args.data)
        if len(self.data) != 1:
            shards = []
        else:
            shards = list(Path(self.data[0]).glob("shard*"))
        if len(shards) > 0:
            # keep this as strings, since it can also be a manifold path
            old_data = self.data
            self.data = [str(shard) for shard in shards]
            logging.warning(
                f"Expanded data directory {old_data} to {self.data}")
Example #11
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        The language-pair dataset is requested with ``prepend_bos=True`` and
        stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``load_langpair_dataset``
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        data_path = all_paths[(epoch - 1) % len(all_paths)]

        # infer langcode
        src = self.args.source_lang
        tgt = self.args.target_lang

        self.datasets[split] = load_langpair_dataset(
            data_path,
            split,
            src,
            self.src_dict,
            tgt,
            self.tgt_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            prepend_bos=True,
        )
Example #12
0
    def prepare(cls, args, **kargs):
        """Normalise ``args`` and load one dictionary per language.

        Args:
            args: parsed arguments; ``lang_pairs`` may be a comma-separated
                string and is converted to a list in place.

        Returns:
            tuple: ``(dicts, training)`` where ``dicts`` maps each language
            (in sorted order) to its dictionary and ``training`` is ``True``
            only when neither ``source_lang`` nor ``target_lang`` is set.

        Raises:
            ValueError: if ``args.lang_pairs`` is ``None``.
        """
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError(
                '--lang-pairs is required. List all the language pairs in the training objective.'
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(',')

        unique_langs = set()
        for lang_pair in args.lang_pairs:
            unique_langs.update(lang_pair.split('-'))
        sorted_langs = sorted(unique_langs)

        training = args.source_lang is None and args.target_lang is None

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dict_file = os.path.join(paths[0], 'dict.{}.txt'.format(lang))
            dicts[lang] = cls.load_dictionary(dict_file)
            if len(dicts) > 0:
                # special symbols must match those of the first language
                reference = dicts[sorted_langs[0]]
                assert dicts[lang].pad() == reference.pad()
                assert dicts[lang].eos() == reference.eos()
                assert dicts[lang].unk() == reference.unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info('[{}] dictionary: {} types'.format(
                lang, len(dicts[lang])))
        return dicts, training
Example #13
0
    def prepare(cls, args, **kargs):
        """Update ``args`` and load one dictionary per language.

        Languages listed in the comma-separated ``args.use_bert_dict`` are
        loaded with ``DictionaryForBert``; all others with ``Dictionary``.

        Returns:
            tuple: ``(dicts, training)`` where ``dicts`` maps each language
            (in sorted order) to its dictionary and ``training`` is ``True``
            only when neither ``source_lang`` nor ``target_lang`` is set.
        """
        cls.update_args(args)

        unique_langs = set()
        for lang_pair in args.lang_pairs:
            unique_langs.update(lang_pair.split("-"))
        sorted_langs = sorted(unique_langs)

        training = args.source_lang is None and args.target_lang is None

        # load dictionaries
        dicts = OrderedDict()
        bert_langs = set(args.use_bert_dict.split(","))
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0

            if lang in bert_langs:
                logger.info("Use DirctionaryForBert for {}".format(lang))
                dictionary_cls = DictionaryForBert
            else:
                logger.info("Use default Dirctionary for {}".format(lang))
                dictionary_cls = Dictionary
            dict_file = os.path.join(paths[0], "dict.{}.txt".format(lang))
            dicts[lang] = dictionary_cls.load(dict_file)

            if len(dicts) > 0:
                # special symbols must match those of the first language
                reference = dicts[sorted_langs[0]]
                assert dicts[lang].pad() == reference.pad()
                assert dicts[lang].eos() == reference.eos()
                assert dicts[lang].unk() == reference.unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info("[{}] dictionary: {} types".format(
                lang, len(dicts[lang])))
        return dicts, training
Example #14
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split through ``load_ape_dataset``.

        Args:
            split (str): name of the split to load
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``load_ape_dataset``
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        data_path = all_paths[(epoch - 1) % len(all_paths)]

        ape_options = dict(
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            prepend_bos=True,
            input_type=self.args.input_type,
            src_type=self.args.src_type,
        )
        self.datasets[split] = load_ape_dataset(
            data_path, split, self.src_dict, self.tgt_dict, **ape_options
        )
Example #15
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Builds the dataset with ``get_asr_dataset_from_json`` and records the
        feature dimension of its source side in ``self.feat_dim``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``get_asr_dataset_from_json``
        """
        all_paths = utils.split_paths(self.args.data)
        assert len(all_paths) > 0
        data_path = all_paths[(epoch - 1) % len(all_paths)]

        # No chunking for validation subsets while in the training stage;
        # otherwise use the configured chunk width.
        if self.training_stage and split in self.args.valid_subset.split(","):
            chunk_width = None
        else:
            chunk_width = self.chunk_width

        self.datasets[split] = get_asr_dataset_from_json(
            data_path,
            split,
            self.dictionary,
            combine=combine,
            upsample_primary=self.args.upsample_primary,
            num_buckets=self.args.num_batch_buckets,
            # the generation subset (if any) is the only one left unshuffled
            shuffle=(split != getattr(self.args, "gen_subset", "")),
            lf_mmi=(self.args.criterion == "lattice_free_mmi"),
            seed=self.args.seed,
            specaugment_config=self.specaugment_config,
            chunk_width=chunk_width,
            chunk_left_context=self.chunk_left_context,
            chunk_right_context=self.chunk_right_context,
            label_delay=self.label_delay,
        )

        # Unwrap one level of container/wrapper to reach the object that
        # exposes ``feat_dim``.
        source_side = self.datasets[split].src
        if isinstance(source_side, ConcatDataset):
            feat_holder = source_side.datasets[0]
        elif isinstance(source_side, BaseWrapperDataset):
            feat_holder = source_side.dataset
        else:
            feat_holder = source_side
        self.feat_dim = feat_holder.feat_dim
Example #16
0
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Normalises the padding flags, infers the language pair from the first
        data path when it is not given, and loads the shared dictionary from
        ``<first data path>/Dicts/dict.txt``.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            An instance of ``cls`` built from ``args``, the loaded dictionary
            and the list of data paths.
        """

        # get padding...
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                paths[0]
            )

        # Build the dictionary path once and report exactly what is loaded.
        # The former debug print joined paths[0] with the absolute component
        # "/Dicts/dict.txt"; os.path.join drops everything before an absolute
        # component, so it always printed "/Dicts/dict.txt".
        dict_path = os.path.join(paths[0] + "/Dicts/", "dict.txt")
        print("path:", dict_path)
        dictionary = cls.load_dictionary(dict_path)

        return cls(args, dictionary, paths)
Example #17
0
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        When ``args.data`` is set, the dictionary is loaded from ``args.dict``
        or, if that is ``None``, from ``dict.txt`` in the first data path. The
        output dictionary is the same object unless
        ``args.output_dictionary_size`` is non-negative, in which case it is
        wrapped in a ``TruncatedDictionary``.

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary = None
        output_dictionary = None
        if args.data:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            if args.dict is None:
                dict_path = os.path.join(paths[0], "dict.txt")
            else:
                dict_path = args.dict
            dictionary = AsrDictionary.load(dict_path)
            logger.info("dictionary: {} types".format(len(dictionary)))
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size)
            else:
                output_dictionary = dictionary

        # upgrade old checkpoints
        if hasattr(args, "exclude_self_target"):
            args.self_target = not args.exclude_self_target

        # Collect the enabled targets in the fixed order self, future, past.
        flag_to_target = (
            ("self_target", "self"),
            ("future_target", "future"),
            ("past_target", "past"),
        )
        targets = [
            target for flag, target in flag_to_target
            if getattr(args, flag, False)
        ]
        if not targets:
            # standard language modeling
            targets = ["future"]

        return cls(args, dictionary, output_dictionary, targets=targets)
Example #18
0
def main(cfg: FairseqConfig):
    """Translate buffered input with a loaded model ensemble and print results.

    Input is read in buffers of ``cfg.interactive.buffer_size`` sentences from
    ``cfg.interactive.input``. For every sentence the following tab-separated
    lines are printed: ``S-`` (source), ``W-`` (per-sentence translation
    time), ``C-`` (constraints), ``H-`` (hypothesis), ``D-`` (detokenized
    hypothesis), ``P-`` (positional scores) and, if requested, ``A-``
    (alignment). The ``S-``/``W-``/``C-`` lines are only printed when the task
    has a source dictionary. Scores are converted to base 2.

    Args:
        cfg: run configuration; an ``argparse.Namespace`` is converted with
            ``convert_namespace_to_omegaconf`` first.

    Raises:
        AssertionError: if sampling is requested with ``nbest != beam``, or
            the batch size exceeds the buffer size.
    """
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    # Fall back to the smallest workable buffer / batch settings.
    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (not cfg.generation.sampling
            or cfg.generation.nbest == cfg.generation.beam
            ), "--sampling requires --nbest to be equal to --beam"
    assert (not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
            ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(cfg.task)

    # Load ensemble
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def encode_fn(x):
        """Apply the tokenizer, then BPE, skipping whichever is ``None``."""
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        """Undo ``encode_fn``: BPE first, then the tokenizer."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models])

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input,
                                cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator,
                                               models,
                                               sample,
                                               constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            # ``sample_id`` instead of ``id`` so the builtin is not shadowed.
            for i, (sample_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((
                    start_id + sample_id,
                    src_tokens_i,
                    hypos,
                    {
                        "constraints": list_constraints[i],
                        # batch time split evenly across its sentences
                        "time": translate_time / len(translations),
                    },
                ))

        # sort output to match input order
        for id_, src_tokens, hypos, info in sorted(results,
                                                   key=lambda x: x[0]):
            # Always (re)initialise: without this, a task with no source
            # dictionary hit a NameError below on the first sentence.
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(src_tokens,
                                          cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(
                        id_,
                        tgt_dict.string(constraint,
                                        cfg.common_eval.post_process)))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print("P-{}\t{}".format(
                    id_,
                    " ".join(
                        map(
                            lambda x: "{:.4f}".format(x),
                            # convert from base e to base 2
                            # (div_ rescales the stored scores in place)
                            hypo["positional_scores"].div_(math.log(2)
                                                           ).tolist(),
                        )),
                ))
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print("A-{}\t{}".format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
Example #19
0
def main(args, task=None, model_state=None):
    """Run inference over ``args.gen_subset`` and report the error rate.

    Depending on ``args`` the function either decodes every batch with a
    flashlight (W2L) decoder and accumulates the error / length counts
    returned by ``process_predictions``, or dumps per-utterance emissions
    (``args.dump_emissions``) or encoder features (``args.dump_features``)
    with ``np.save`` instead of decoding.

    Args:
        args: parsed option namespace read throughout this function
            (``path``, ``gen_subset``, ``criterion``, ``w2l_decoder``,
            ``load_emissions``, ``dump_emissions``, ``dump_features``, ...).
        task: accepted for interface compatibility.
            NOTE(review): the value passed in is overwritten by
            ``tasks.setup_task(args)`` below, so it is currently ignored;
            confirm whether that is intended.
        model_state: forwarded as ``state`` to
            ``checkpoint_utils.load_model_ensemble``.

    Returns:
        tuple: ``(task, wer)``.  ``wer`` is ``errs * 100 / lengths`` and is
        ``None`` when emissions/features were dumped or when no reference
        length was accumulated.

    Raises:
        NotImplementedError: if ``args.criterion == "asg_loss"``.
    """
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        # Emissions come from disk, so no acoustic model needs to be loaded.
        models, criterions = [], []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg = checkpoint_utils.load_model_ensemble(
            utils.split_paths(args.path),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        # The dataset is loaded after the checkpoint so the saved task
        # config can be handed to it.
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| {} {} {} examples".format(
        args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")
        # trans = criterions[0].asg.trans.data
        # args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        """Return the W2L decoder selected by ``args.w2l_decoder``.

        For an unrecognised (or missing) value a message is printed and
        ``None`` is returned implicitly.
        """
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        # Wrap the decoder so it consumes the pre-computed emissions file.
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True))
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    # NOTE(review): max_source_pos is computed here but not read again in
    # this function; kept because the max_positions() calls are preserved.
    max_source_pos = (utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]), )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, :args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                # Store the first model's normalized log-probs per sample id
                # and skip decoding for this batch.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out,
                                                         log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        emissions[utt_id.item()] = emm[i]
                    continue
            elif args.dump_features:
                # Store (encoder output, padding mask or None) per sample id
                # and skip decoding for this batch.
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(
                        0, 1).cpu().numpy()
                    for i, utt_id in enumerate(sample["id"]):
                        padding = (encoder_out["encoder_padding_mask"][i].cpu(
                        ).numpy() if encoder_out["encoder_padding_mask"]
                                   is not None else None)
                        features[utt_id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # utt_id = task.dataset(args.gen_subset).ids[int(sample_id)]
                utt_id = sample_id
                # Prefer "target_label" over "target" when the sample has it.
                toks = (sample["target"][i, :] if "target_label" not in sample
                        else sample["target_label"][i, :])
                target_tokens = utils.strip_pad(toks,
                                                tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    utt_id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (sample["nsentences"] if "nsentences" in sample
                              else sample["id"].numel())

    wer = None
    if args.dump_emissions:
        # Indexing by 0..N-1 relies on the stored sample ids being exactly
        # that range; a gap raises KeyError, as before.
        emm_arr = [emissions[i] for i in range(len(emissions))]
        np.save(args.dump_emissions, emm_arr)
        logger.info(
            f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = [features[i] for i in range(len(features))]
        np.save(args.dump_features, feat_arr)
        # Fixed: this message used to say "emissions".
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        # Guard the rates: when nothing was timed the unconditional
        # divisions raised ZeroDivisionError.
        elapsed = gen_timer.sum
        sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
        avg_time = gen_timer.avg
        tokens_per_sec = 1.0 / avg_time if avg_time > 0 else 0.0

        # Fixed: a space was missing between the rate and "sentences/s".
        logger.info("| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
                    "sentences/s, {:.2f} tokens/s)".format(
                        num_sentences,
                        gen_timer.n,
                        elapsed,
                        sentences_per_sec,
                        tokens_per_sec,
                    ))
        logger.info("| Generate {} with beam={}".format(
            args.gen_subset, args.beam))
    return task, wer
    def prepare(cls, load_dictionary, args, **kargs):
        """Normalise ``args`` and load one dictionary per language.

        Args:
            load_dictionary: callable given a ``dict.<lang>.txt`` path; its
                result is stored as that language's dictionary.
            args: option namespace; several attributes are normalised in
                place (``left_pad_*``, ``shuffle_instance``, ``langtoks``,
                ``lang_pairs``).
            **kargs: forwarded to ``cls.load_langs``.

        Returns:
            tuple: ``(language_list, dicts, training)`` where ``dicts`` is an
            ``OrderedDict`` keyed by language and ``training`` is ``True``
            when neither ``source_lang`` nor ``target_lang`` is set.

        Raises:
            ValueError: if ``--lang-pairs`` is missing, or a pair mentions a
                language that ``cls.load_langs`` did not return.
        """
        args.left_pad_source = utils.eval_bool(args.left_pad_source)
        args.left_pad_target = utils.eval_bool(args.left_pad_target)

        # Fill in defaults for options that may be absent or unset.
        if not hasattr(args, "shuffle_instance"):
            args.shuffle_instance = False
        if args.langtoks is None:
            args.langtoks = {}
        if "main" not in args.langtoks:
            args.langtoks["main"] = (
                args.encoder_langtok or None,
                "tgt" if args.decoder_langtok else None,
            )

        def check_langs(langs, pairs):
            # Collect one message per offending pair, then fail once.
            problems = [
                f"language pair {src}-{tgt} contains languages "
                "that are not in the language dictionary"
                for src, tgt in pairs
                if src not in langs or tgt not in langs
            ]
            if problems:
                raise ValueError(" ".join(problems) + f"; langs: {langs}")

        if args.lang_pairs is None:
            raise ValueError(
                "--lang-pairs is required. List all the language pairs in the training objective."
            )
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

        # An explicit source or target language means we are not training.
        training = args.source_lang is None and args.target_lang is None

        language_list = cls.load_langs(args, **kargs)
        if training:
            pairs_to_check = [pair.split("-") for pair in args.lang_pairs]
        else:
            pairs_to_check = [(args.source_lang, args.target_lang)]
        check_langs(language_list, pairs_to_check)

        # load dictionaries
        if training:
            if args.extra_lang_pairs:
                extra_lang_pairs = list(
                    {
                        pair
                        for spec in args.extra_lang_pairs.values()
                        for pair in spec.split(",")
                    }
                )
            else:
                extra_lang_pairs = []
            wanted = set()
            for pair in args.lang_pairs + extra_lang_pairs:
                wanted.update(pair.split("-"))
            langs_to_load_dicts = sorted(wanted)
        else:
            langs_to_load_dicts = sorted([args.source_lang, args.target_lang])

        dicts = OrderedDict()
        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        dict_dir = paths[0]
        for lang in langs_to_load_dicts:
            dict_file = os.path.join(dict_dir, "dict.{}.txt".format(lang))
            dicts[lang] = load_dictionary(dict_file)
            augment_dictionary(
                dictionary=dicts[lang],
                language_list=language_list,
                lang_tok_style=args.lang_tok_style,
                langtoks_specs=args.langtoks_specs,
                extra_data=args.extra_data,
            )
            if dicts:
                # Every dictionary must share the special-symbol indices of
                # the first one loaded.
                current = dicts[lang]
                reference = dicts[langs_to_load_dicts[0]]
                assert current.pad() == reference.pad()
                assert current.eos() == reference.eos()
                assert current.unk() == reference.unk()
            logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang])))
        return language_list, dicts, training
Example #21
0
def _main(cfg: DictConfig, output_file):
    """Generate hypotheses for ``cfg.dataset.gen_subset`` and score them.

    Loads the model ensemble (and, when ``cfg.generation.lm_path`` is set, a
    language model), iterates over the subset without shuffling, prints the
    S-/T-/H-/D-/P-/A-/I-/E- lines to ``output_file`` unless
    ``cfg.common_eval.quiet`` is set, and feeds the top hypothesis of every
    sample that has a target to the scorer.

    Args:
        cfg: configuration object; the ``common``, ``common_eval``,
            ``dataset``, ``generation``, ``checkpoint``,
            ``distributed_training``, ``task``, ``tokenizer``, ``bpe`` and
            ``scoring`` groups are read here.
        output_file: stream used both for logging and for the printed
            generation output.

    Returns:
        The scorer built by ``scoring.build_scorer``.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=output_file,
    )
    logger = logging.getLogger("fairseq_cli.generate")

    utils.import_user_module(cfg.common)

    # Fall back to a token budget when no batch limit was given at all.
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.max_tokens = 12000
    logger.info(cfg)

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    # Load dataset splits
    task = tasks.setup_task(cfg.task)

    # Set dictionaries
    # Accessing source_dictionary may raise NotImplementedError for some
    # tasks; in that case the source side is treated as absent.
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    overrides = ast.literal_eval(cfg.common_eval.model_overrides)

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    if cfg.generation.lm_path is not None:
        # Point the language model at the same data directory as the task.
        overrides["data"] = cfg.task.data

        # The bare except only adds a hint to the log; the original
        # exception is always re-raised.
        try:
            lms, _ = checkpoint_utils.load_model_ensemble(
                [cfg.generation.lm_path], arg_overrides=overrides, task=None)
        except:
            logger.warning(
                f"Failed to load language model! Please make sure that the language model dict is the same "
                f"as target dict and is located in the data dir ({cfg.task.data})"
            )
            raise

        assert len(lms) == 1
    else:
        # Placeholder so that lms[0] below is valid and means "no LM".
        lms = [None]

    # Optimize ensemble for generation
    for model in chain(models, lms):
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()

    extra_gen_cls_kwargs = {
        "lm_model": lms[0],
        "lm_weight": cfg.generation.lm_weight
    }
    generator = task.build_generator(models,
                                     cfg.generation,
                                     extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def decode_fn(x):
        """Undo BPE first, then tokenization, skipping whichever is None."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(cfg.scoring, tgt_dict)

    num_sentences = 0
    # Stays True if the loop never runs; otherwise reflects the last sample.
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        # Batches without model input are skipped entirely.
        if "net_input" not in sample:
            continue

        prefix_tokens = None
        if cfg.generation.prefix_size > 0:
            prefix_tokens = sample["target"][:, :cfg.generation.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        # Only the inference step itself is timed.
        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens=prefix_tokens,
            constraints=constraints,
        )
        # Token count uses the first hypothesis of each sample.
        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample["id"].tolist()):
            has_target = sample["target"] is not None

            # Remove padding
            # NOTE(review): source tokens are stripped with tgt_dict.pad();
            # presumably both dictionaries share the pad index - confirm.
            if "src_tokens" in sample["net_input"]:
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = (utils.strip_pad(sample["target"][i, :],
                                                 tgt_dict.pad()).int().cpu())

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    cfg.dataset.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    cfg.dataset.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens,
                                              cfg.common_eval.post_process)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        cfg.common_eval.post_process,
                        escape_unk=True,
                        extra_symbols_to_ignore=
                        get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not cfg.common_eval.quiet:
                if src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print("T-{}\t{}".format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:cfg.generation.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not cfg.common_eval.quiet:
                    score = hypo["score"] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print(
                        "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
                        file=output_file,
                    )
                    # detokenized hypothesis
                    print(
                        "D-{}\t{}\t{}".format(sample_id, score,
                                              detok_hypo_str),
                        file=output_file,
                    )
                    # Note: div_ rescales hypo["positional_scores"] in place.
                    print(
                        "P-{}\t{}".format(
                            sample_id,
                            " ".join(
                                map(
                                    lambda x: "{:.4f}".format(x),
                                    # convert from base e to base 2
                                    hypo["positional_scores"].div_(math.log(2)
                                                                   ).tolist(),
                                )),
                        ),
                        file=output_file,
                    )

                    # "hard" alignments are printed as src-tgt index pairs.
                    if cfg.generation.print_alignment == "hard":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ]),
                            ),
                            file=output_file,
                        )
                    # "soft" alignments are printed as comma-joined rows.
                    # NOTE(review): ",".join requires string elements -
                    # confirm what post_process_prediction returns here.
                    if cfg.generation.print_alignment == "soft":
                        print(
                            "A-{}\t{}".format(
                                sample_id,
                                " ".join([
                                    ",".join(src_probs)
                                    for src_probs in alignment
                                ]),
                            ),
                            file=output_file,
                        )

                    if cfg.generation.print_step:
                        print(
                            "I-{}\t{}".format(sample_id, hypo["steps"]),
                            file=output_file,
                        )

                    # One E- line per entry of the hypothesis history.
                    if cfg.generation.retain_iter_history:
                        for step, h in enumerate(hypo["history"]):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h["tokens"].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print(
                                "E-{}_{}\t{}".format(sample_id, step, h_str),
                                file=output_file,
                            )

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or cfg.common_eval.post_process is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    # Scorers exposing add_string get strings; others tokens.
                    if hasattr(scorer, "add_string"):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({"wps": round(wps_meter.avg)})
        num_sentences += (sample["nsentences"]
                          if "nsentences" in sample else sample["id"].numel())

    logger.info("NOTE: hypothesis and token scores are output in base 2")
    # NOTE(review): both rates below divide by timer values with no zero
    # guard - confirm behaviour when no batch was timed.
    logger.info(
        "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)"
        .format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        ))
    if has_target:
        if cfg.bpe and not cfg.generation.sacrebleu:
            if cfg.common_eval.post_process:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            "Generate {} with beam={}: {}".format(cfg.dataset.gen_subset,
                                                  cfg.generation.beam,
                                                  scorer.result_string()),
            file=output_file,
        )

    return scorer
Example #22
0
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Build the masked dataset for ``split`` and store it in ``self.datasets``.

        The indexed data is optionally shortened, cut into token blocks,
        prefixed with the beginning-of-sentence token, masked into a
        (source, target) pair and finally wrapped in a sorted dataset.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): selects the data directory as
                ``paths[(epoch - 1) % len(paths)]`` and, together with the
                seed, the shuffle permutation.
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``.

        Raises:
            FileNotFoundError: if no dataset is found at the split path.
        """
        cfg = self.args
        vocab = self.source_dictionary

        all_paths = utils.split_paths(cfg.data)
        assert len(all_paths) > 0
        # Rotate through the available data directories by epoch.
        chosen_dir = all_paths[(epoch - 1) % len(all_paths)]
        split_path = os.path.join(chosen_dir, split)

        tokens = data_utils.load_indexed_dataset(
            split_path, vocab, cfg.dataset_impl, combine=combine
        )
        if tokens is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )

        tokens = maybe_shorten_dataset(
            tokens,
            split,
            cfg.shorten_data_split_list,
            cfg.shorten_method,
            cfg.tokens_per_sample,
            cfg.seed,
        )

        # Cut the stream into blocks, leaving one position free for <s>.
        block_size = cfg.tokens_per_sample - 1
        blocks = TokenBlockDataset(
            tokens,
            tokens.sizes,
            block_size,
            pad=vocab.pad(),
            eos=vocab.eos(),
            break_mode=cfg.sample_break_mode,
        )
        logger.info("loaded {} blocks from: {}".format(len(blocks), split_path))

        # Prepend <s> to every block.
        with_bos = PrependTokenDataset(blocks, vocab.bos())

        # Whole-word masking is only set up when requested.
        if cfg.mask_whole_words:
            mask_whole_words = get_whole_word_mask(cfg, vocab)
        else:
            mask_whole_words = None

        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            with_bos,
            vocab,
            pad_idx=vocab.pad(),
            mask_idx=self.mask_idx,
            seed=cfg.seed,
            mask_prob=cfg.mask_prob,
            leave_unmasked_prob=cfg.leave_unmasked_prob,
            random_token_prob=cfg.random_token_prob,
            freq_weighted_replacement=cfg.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
            mask_multiple_length=cfg.mask_multiple_length,
            mask_stdev=cfg.mask_stdev,
        )

        # The permutation is seeded from the seed and the epoch.
        with data_utils.numpy_seed(cfg.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))

        net_input = {
            "src_tokens": RightPadDataset(src_dataset, pad_idx=vocab.pad()),
            "src_lengths": NumelDataset(src_dataset, reduce=False),
        }
        fields = {
            "id": IdDataset(),
            "net_input": net_input,
            "target": RightPadDataset(tgt_dataset, pad_idx=vocab.pad()),
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_dataset, reduce=True),
        }
        nested = NestedDictionaryDataset(fields, sizes=[src_dataset.sizes])

        self.datasets[split] = SortDataset(
            nested,
            sort_order=[shuffle, src_dataset.sizes],
        )
Example #23
0
def _main(args, output_file):
    """Decode ``args.gen_subset`` with a model ensemble and write WER/CER reports.

    Sets up the task and the models listed in ``args.path``, wraps word-level
    language models for LM fusion, generates hypotheses batch by batch and
    feeds the top hypothesis of every utterance to a ``wer.Scorer``. Result
    files (``decoded_results.txt``, ``decoded_char_results.txt`` and, when
    targets are present, ``wer``, ``cer`` and ``aligned_results.txt``) are
    written under ``args.results_path``.

    Args:
        args: parsed command-line options.
        output_file: stream that receives the log records and the
            ``T-``/``H-`` lines printed for every utterance.

    Returns:
        The ``wer.Scorer`` that accumulated the predictions.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('espresso.speech_recognize')
    if output_file is not sys.stdout:  # also print to stdout
        logger.addHandler(logging.StreamHandler(sys.stdout))

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    # Fall back to a token budget when no batch limit was given at all
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionary
    dictionary = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes whatever expression --model-overrides
    # contains; acceptable only because it comes from the trusted command
    # line. ast.literal_eval would be the stricter alternative.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )
    for i, m in enumerate(models):
        if hasattr(m, 'is_wordlm') and m.is_wordlm:
            # assume subword LM comes before word LM
            if isinstance(models[i - 1], FairseqLanguageModel):
                models[i - 1] = MultiLevelLanguageModel(
                    m,
                    models[i - 1],
                    subwordlm_weight=args.subwordlm_weight,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                # NOTE(review): this shrinks the list being enumerated, so an
                # element following the word LM would be skipped -- presumably
                # the word LM is always last; confirm.
                del models[i]
                logger.info('LM fusion with Multi-level LM')
            else:
                models[i] = TensorizedLookaheadLanguageModel(
                    m,
                    dictionary,
                    oov_penalty=args.oov_penalty,
                    open_vocab=not args.disable_open_vocab,
                )
                logger.info('LM fusion with Look-ahead Word LM')
        # assume subword LM comes after E2E models
        elif i == len(models) - 1 and isinstance(m, FairseqLanguageModel):
            logger.info('LM fusion with Subword LM')
    if args.lm_weight != 0.0:
        logger.info('using LM fusion with lm-weight={:.2f}'.format(
            args.lm_weight))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[
                model.max_positions() if hasattr(model, 'encoder') else
                (None, model.max_positions()) for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    if args.match_source_len:
        logger.warning(
            'The option match_source_len is not applicable to speech recognition. Ignoring it.'
        )
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE first, then tokenization, skipping whichever is unset."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute WER
    scorer = wer.Scorer(dictionary, wer_output_filter=args.wer_output_filter)
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    # Stays None until an attention plot is actually written, so the summary
    # log below cannot hit an unbound name when nothing was plotted.
    save_dir = None
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(
            generator,
            models,
            sample,
            prefix_tokens,
            lm_weight=args.lm_weight,
        )
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        # obtain nonpad mask of encoder output to plot attentions
        if args.print_alignment:
            net_input = sample['net_input']
            src_tokens = net_input['src_tokens']
            output_lengths = models[0].encoder.output_lengths(
                net_input['src_lengths'])
            nonpad_idxs = sequence_mask(
                output_lengths,
                models[0].encoder.output_lengths(src_tokens.size(1)))

        for i in range(len(sample['id'])):
            has_target = sample['target'] is not None
            utt_id = sample['utt_id'][i]

            # Retrieve the original sentences
            if has_target:
                target_str = sample['target_raw_text'][i]
                if not args.quiet:
                    detok_target_str = decode_fn(target_str)
                    print('T-{}\t{}'.format(utt_id, detok_target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_str = dictionary.string(
                    hypo['tokens'].int().cpu(),
                    bpe_symbol=None,
                    extra_symbols_to_ignore={dictionary.pad()},
                )  # not removing bpe at this point
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    print('H-{}\t{}\t{}'.format(utt_id, detok_hypo_str, score),
                          file=output_file)

                # Score and obtain attention only the top hypothesis
                if j == 0:
                    # src_len x tgt_len
                    attention = hypo['attention'][nonpad_idxs[i]].float().cpu() \
                        if args.print_alignment and hypo['attention'] is not None else None
                    if args.print_alignment and attention is not None:
                        save_dir = os.path.join(args.results_path,
                                                'attn_plots')
                        os.makedirs(save_dir, exist_ok=True)
                        plot_attention(attention, detok_hypo_str, utt_id,
                                       save_dir)
                    scorer.add_prediction(utt_id, hypo_str)
                    if has_target:
                        scorer.add_evaluation(utt_id, target_str, hypo_str)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Recognized {} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if args.print_alignment and save_dir is not None:
        logger.info('Saved attention plots in ' + save_dir)

    if has_target:
        scorer.add_ordered_utt_list(task.datasets[args.gen_subset].tgt.utt_ids)

    fn = 'decoded_char_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_char_results())
        logger.info('Decoded char results saved as ' + f.name)

    fn = 'decoded_results.txt'
    with open(os.path.join(args.results_path, fn), 'w', encoding='utf-8') as f:
        f.write(scorer.print_results())
        logger.info('Decoded results saved as ' + f.name)

    if has_target:
        header = 'Recognize {} with beam={}: '.format(args.gen_subset,
                                                      args.beam)
        fn = 'wer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.wer()))
            logger.info(header + res)
            f.write(res + '\n')
            logger.info('WER saved in ' + f.name)

        fn = 'cer'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            res = 'CER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%'.format(
                *(scorer.cer()))
            # pad with spaces so the CER line aligns under the WER line
            logger.info(' ' * len(header) + res)
            f.write(res + '\n')
            logger.info('CER saved in ' + f.name)

        fn = 'aligned_results.txt'
        with open(os.path.join(args.results_path, fn), 'w',
                  encoding='utf-8') as f:
            f.write(scorer.print_aligned_results())
            logger.info('Aligned results saved as ' + f.name)
    return scorer
Example #24
0
 def setup_task(cls, args, **kwargs):
     """Create the task from the ``dict.txt`` found in the first data path.

     Args:
         args: parsed options; ``args.data`` holds the data path(s).

     Returns:
         A new ``cls`` instance built from ``args`` and the loaded dictionary.
     """
     data_dirs = utils.split_paths(args.data)
     assert len(data_dirs) > 0
     # the dictionary always lives in the first data directory
     dict_file = os.path.join(data_dirs[0], "dict.txt")
     vocab = Dictionary.load(dict_file)
     num_types = len(vocab)
     logger.info("dictionary: {} types".format(num_types))
     return cls(args, vocab)
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Every sub-directory of the data path selected for ``epoch`` is treated
        as one language. Each language is loaded, cut into token blocks,
        prefixed with the beginning-of-sentence token and masked. For the
        training subset the per-language datasets are resampled before being
        concatenated; for any other split each language is additionally
        registered as ``self.datasets[split + '_' + language]``. The shuffled,
        size-sorted result is stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number; picks the data path and seeds
                the resampling and shuffling
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if a language has no dataset for ``split``
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]

        languages = sorted(name for name in os.listdir(data_path)
                           if os.path.isdir(os.path.join(data_path, name)))

        logger.info("Training on {0} languages: {1}".format(
            len(languages), languages))
        # The dict is a logging argument, so the message needs a placeholder;
        # without one the mapping never shows up in the log.
        logger.info("Language to id mapping: %s",
                    {lang: idx
                     for idx, lang in enumerate(languages)})

        mask_whole_words = self._get_whole_word_mask()
        lang_datasets = []
        for lang_id, language in enumerate(languages):
            split_path = os.path.join(data_path, language, split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                self.source_dictionary,
                self.args.dataset_impl,
                combine=combine,
            )
            if dataset is None:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(
                    split, split_path))

            # create continuous blocks of tokens
            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode=self.args.sample_break_mode,
            )
            logger.info('loaded {} blocks from: {}'.format(
                len(dataset), split_path))

            # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
            dataset = PrependTokenDataset(dataset,
                                          self.source_dictionary.bos())

            src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
                dataset,
                self.source_dictionary,
                pad_idx=self.source_dictionary.pad(),
                mask_idx=self.mask_idx,
                seed=self.args.seed,
                mask_prob=self.args.mask_prob,
                leave_unmasked_prob=self.args.leave_unmasked_prob,
                random_token_prob=self.args.random_token_prob,
                freq_weighted_replacement=self.args.freq_weighted_replacement,
                mask_whole_words=mask_whole_words,
            )

            lang_dataset = NestedDictionaryDataset(
                {
                    'net_input': {
                        'src_tokens':
                        PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths':
                        NumelDataset(src_dataset, reduce=False),
                    },
                    'target':
                    PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences':
                    NumSamplesDataset(),
                    'ntokens':
                    NumelDataset(src_dataset, reduce=True),
                    # one language label per example
                    'lang_id':
                    RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
                },
                sizes=[src_dataset.sizes],
            )
            lang_datasets.append(lang_dataset)

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info('loaded total {} blocks for all languages'.format(
            dataset_lengths.sum(), ))
        if split == self.args.train_subset:
            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(dataset_lengths)
            logger.info(
                "Sample probability by language: %s", {
                    lang: "{0:.4f}".format(sample_probs[idx])
                    for idx, lang in enumerate(languages)
                })
            size_ratio = (sample_probs *
                          dataset_lengths.sum()) / dataset_lengths
            logger.info(
                "Up/Down Sampling ratio by language: %s", {
                    lang: "{0:.2f}".format(size_ratio[idx])
                    for idx, lang in enumerate(languages)
                })

            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_dataset,
                    size_ratio=size_ratio[i],
                    seed=self.args.seed,
                    epoch=epoch,
                    # sample with replacement only when upsampling
                    replace=size_ratio[i] >= 1.0,
                ) for i, lang_dataset in enumerate(lang_datasets)
            ]
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for lang_id, lang_dataset in enumerate(lang_datasets):
                split_name = split + '_' + languages[lang_id]
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            if split in self.args.valid_subset:
                self.args.valid_subset = self.args.valid_subset.replace(
                    split, ','.join(lang_splits))

        # seed depends on the epoch so each epoch gets its own permutation
        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )
Example #26
0
    def __init__(self):
        """Parse the interactive-generation options and build ``self.context``.

        Loads the task, the model ensemble, the generator, the tokenizer/BPE
        helpers and the alignment dictionary once, and stores them together
        with the resolved maximum positions in the ``self.context`` dict.
        """
        cli_args = options.parse_args_and_arch(
            options.get_interactive_generation_parser())
        cfg = convert_namespace_to_omegaconf(cli_args)
        utils.import_user_module(cfg.common)

        # Normalise buffer and batch sizes before validating them
        if cfg.interactive.buffer_size < 1:
            cfg.interactive.buffer_size = 1
        if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
            cfg.dataset.batch_size = 1

        if cfg.generation.sampling:
            assert cfg.generation.nbest == cfg.generation.beam, \
                "--sampling requires --nbest to be equal to --beam"
        if cfg.dataset.batch_size:
            assert cfg.dataset.batch_size <= cfg.interactive.buffer_size, \
                "--batch-size cannot be larger than --buffer-size"

        use_cuda = torch.cuda.is_available() and not cfg.common.cpu

        # Setup task, e.g., translation
        task = tasks.setup_task(cfg.task)

        # Load ensemble
        shard_count = cfg.checkpoint.checkpoint_shard_count
        models, _model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(cfg.common_eval.path),
            task=task,
            suffix=cfg.checkpoint.checkpoint_suffix,
            strict=(shard_count == 1),
            num_shards=shard_count,
        )

        # Set dictionaries
        src_dict = task.source_dictionary
        tgt_dict = task.target_dictionary

        # Optimize ensemble for generation; missing (None) models are skipped
        for net in models:
            if net is not None:
                if cfg.common.fp16:
                    net.half()
                if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
                    net.cuda()
                net.prepare_for_inference_(cfg)

        # Initialize generator
        generator = task.build_generator(models, cfg.generation)

        # Handle tokenization and BPE
        tokenizer = encoders.build_tokenizer(cfg.tokenizer)
        bpe = encoders.build_bpe(cfg.bpe)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        align_dict = utils.load_align_dict(cfg.generation.replace_unk)

        task_limit = task.max_positions()
        model_limits = [net.max_positions() for net in models]
        max_positions = utils.resolve_max_positions(task_limit, *model_limits)
        if cfg.interactive.buffer_size > 1:
            logger.info("Sentence buffer size: %s",
                        cfg.interactive.buffer_size)

        self.context = dict(
            bpe=bpe,
            tokenizer=tokenizer,
            cfg=cfg,
            task=task,
            max_positions=max_positions,
            use_cuda=use_cuda,
            generator=generator,
            models=models,
            src_dict=src_dict,
            tgt_dict=tgt_dict,
            align_dict=align_dict,
        )
0
    def load_dataset(self, split, epoch=1, combine=False):
        """Load a given dataset split.

        Shards named ``split``, ``split1``, ``split2``, ... are looked up in
        the data path selected for ``epoch``; only the first shard is used
        unless ``combine`` is set. The shards are wrapped in a
        ``MaskedLMDataset`` that is stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): when true, keep loading numbered shards until one
                is missing

        Raises:
            FileNotFoundError: if the first shard of ``split`` cannot be loaded
        """
        loaded_datasets = []

        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        # The path is a logging argument and needs a placeholder; without one
        # logging reports a formatting error instead of the path.
        logger.info("data_path: %s", data_path)

        for k in itertools.count():
            # the first shard has no numeric suffix
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)
            ds = indexed_dataset.make_dataset(
                path,
                impl=self.args.dataset_impl,
                fix_lua_indexing=True,
                dictionary=self.dictionary,
            )

            if ds is None:
                if k > 0:
                    # ran out of numbered shards
                    break
                else:
                    raise FileNotFoundError(
                        'Dataset not found: {} ({})'.format(split, data_path))

            # seed per shard so block pairing is reproducible
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    ))

            logger.info('{} {} {} examples'.format(
                data_path, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=self.args.shuffle_dataset,
            seed=self.seed,
        )
Example #28
0
def _main(args, output_file):
    """Generate hypotheses for ``args.gen_subset`` and score them with BLEU.

    Sets up the task and the model ensemble from ``args.path``, iterates over
    the subset in batches, prints source/target/hypothesis lines to
    ``output_file`` (unless ``args.quiet``) and adds the top hypothesis of
    every sentence to a BLEU scorer.

    Args:
        args: parsed command-line options.
        output_file: stream that receives both the log records and the
            printed generation output.

    Returns:
        The scorer (``bleu.SacrebleuScorer`` or ``bleu.Scorer``) holding the
        accumulated statistics.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    # Fall back to a token budget when no batch limit was given at all
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    # (a task without a usable source dictionary ends up with src_dict None)
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() runs whatever expression --model-overrides holds;
    # only safe for trusted command-line input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        """Undo BPE first, then tokenization, skipping whichever is unset."""
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        # skip batches that carry no model input
        if 'net_input' not in sample:
            continue

        # optionally force the first prefix_size target tokens
        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
        # only the best hypothesis of each sentence counts towards throughput
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            src_tokens = utils.strip_pad(
                sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :],
                                                tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(
                    args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(
                    args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(target_tokens,
                                                 args.remove_bpe,
                                                 escape_unk=True,
                                                 extra_symbols_to_ignore={
                                                     generator.eos,
                                                 })

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str),
                          file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str),
                          file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore={
                        generator.eos,
                    })
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str),
                          file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score,
                                                detok_hypo_str),
                          file=output_file)
                    # per-position scores; div_ rescales the tensor stored in
                    # hypo in place
                    print(
                        'P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2
                                    hypo['positional_scores'].div_(math.log(2)
                                                                   ).tolist(),
                                ))),
                        file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id, ' '.join([
                                '{}-{}'.format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in alignment
                            ])),
                              file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']),
                              file=output_file)

                    # optional per-hypothesis diagnostics, printed only when
                    # the hypothesis carries them / the option is set
                    if 'enc_selection' in hypo:
                        print('Menc-{}\t{}'.format(sample_id,
                                                   hypo['enc_selection']),
                              file=output_file)
                    if 'dec_selection' in hypo:
                        print('Mdec-{}\t{}'.format(sample_id,
                                                   hypo['dec_selection']),
                              file=output_file)
                    if args.print_attn_confidence:
                        print('C-{}\t{}'.format(sample_id,
                                                hypo['enc_self_attn_conf']),
                              file=output_file)

                    # print every intermediate hypothesis kept in the history
                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str),
                                  file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(
                            target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(
                            detok_hypo_str, add_if_not_exist=True)
                    # scorers exposing add_string take strings, others take token tensors
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # NOTE(review): divides by gen_timer.sum / gen_timer.avg -- presumably
    # both are zero if no batch was generated; confirm this cannot happen.
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        # warn that the non-sacrebleu score depends on how BPE was handled
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning(
                    "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
                )
            else:
                logger.warning(
                    "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words.  Use --sacrebleu for standard 13a BLEU tokenization"
                )
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))

    return scorer
Example #29
0
    def load_dataset(self,
                     split: str,
                     epoch=1,
                     combine=False,
                     **kwargs) -> MonolingualDataset:
        """Load a given dataset split.

        The split is read from the data path selected for ``epoch``,
        optionally shortened, cut into token blocks with targets and wrapped
        in a ``MonolingualDataset``. The result is stored in
        ``self.datasets[split]``; the method has no ``return`` statement.

        Args:
            split (str): name of the split (e.g., train, valid, valid1, test)
            epoch (int): 1-based epoch number used to pick the data path
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``

        Raises:
            FileNotFoundError: if no dataset exists at the split path
        """
        data_dirs = utils.split_paths(self.args.data)
        assert len(data_dirs) > 0

        data_dir = data_dirs[(epoch - 1) % len(data_dirs)]
        split_path = os.path.join(data_dir, split)

        # each process has its own copy of the raw data (likely to be an np.memmap)
        raw_dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if raw_dataset is None:
            raise FileNotFoundError(
                f"Dataset not found: {split} ({split_path})")

        shortened = maybe_shorten_dataset(
            raw_dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            self.args.tokens_per_sample,
            self.args.seed,
        )
        block_dataset = TokenBlockDataset(
            shortened,
            shortened.sizes,
            self.args.tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
            use_plasma_view=self.args.use_plasma_view,
            split_path=split_path,
            plasma_path=self.args.plasma_path,
        )

        # true for every break mode other than None / "none"
        break_mode = self.args.sample_break_mode
        add_eos_for_other_targets = (break_mode is not None
                                     and break_mode != "none")

        if self.args.pad_to_fixed_length:
            fixed_pad_length = self.args.tokens_per_sample
        else:
            fixed_pad_length = None

        # validation splits use their own batch size when padding batches
        if not self.args.pad_to_fixed_bsz:
            pad_to_bsz = None
        elif "valid" in split:
            pad_to_bsz = self.args.batch_size_valid
        else:
            pad_to_bsz = self.args.batch_size

        self.datasets[split] = MonolingualDataset(
            dataset=block_dataset,
            sizes=block_dataset.sizes,
            src_vocab=self.dictionary,
            tgt_vocab=self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets,
            shuffle=True,
            targets=self.targets,
            add_bos_token=self.args.add_bos_token,
            fixed_pad_length=fixed_pad_length,
            pad_to_bsz=pad_to_bsz,
        )
Example #30
0
def _main(args, output_file):
    """Run a model ensemble over a dataset split and dump its outputs as Kaldi matrices.

    For every utterance in ``args.gen_subset`` the result of
    ``task.inference_step`` (named ``lprobs`` here; presumably per-frame
    log-probabilities -- confirm against the task) is written with
    ``kaldi_io.write_mat`` keyed by the utterance id. If a state prior is
    available (from ``args.state_prior_file`` or from the models'
    ``state_prior`` attribute), its log is subtracted from the outputs first.

    Args:
        args: parsed options namespace. Attributes read here include
            ``max_tokens``, ``max_sentences``, ``cpu``, ``fp16``, ``path``,
            ``model_overrides``, ``state_prior_file``, ``gen_subset`` and the
            batching / sharding / logging options passed to the iterator and
            progress bar.
        output_file: stream handed to ``logging.basicConfig`` so that log
            records go there.

    Returns:
        None.
    """
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger("espresso.dump_posteriors")

    print_options_meaning_changes(args, logger)

    utils.import_user_module(args)

    # Fall back to a default token budget when no batch-size limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset split
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Load ensemble
    logger.info("loading model(s) from {}".format(args.path))
    # NOTE(security): eval() executes arbitrary code contained in
    # --model-overrides. It is kept for backward compatibility with existing
    # invocations; consider ast.literal_eval if only literal dicts are needed.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Load state prior for cross-entropy trained systems decoding.
    # `prior` is a tensor when read from file; otherwise it starts as a list
    # that collects the per-model priors in the loop below.
    if args.state_prior_file is not None:
        prior = torch.from_numpy(kaldi_io.read_vec_flt(args.state_prior_file))
    else:
        prior = []

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        if isinstance(prior, list) and getattr(model, "state_prior",
                                               None) is not None:
            prior.append(model.state_prior.unsqueeze(0))

    if isinstance(prior, list) and len(prior) > 0:
        prior = torch.cat(prior, 0).mean(0)  # average priors across models
        prior = prior / prior.sum()  # re-normalize
    elif isinstance(prior, list):
        # Neither a prior file nor any model-provided prior: no normalization.
        prior = None

    if prior is not None:
        if args.fp16:
            prior = prior.half()
        if use_cuda:
            prior = prior.cuda()
        log_prior = prior.log()
    else:
        log_prior = None

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[
                model.max_positions() if hasattr(model, "encoder") else
                (None, model.max_positions()) for model in models
            ]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=("tqdm" if not args.no_progress_bar else "none"),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Generate and dump
    num_sentences = 0
    chunk_width = getattr(task, "chunk_width", None)
    # Kaldi write specifier; presumably pipes the matrices through
    # `copy-matrix` to standard output -- confirm against kaldi_io.open_or_fd.
    lprobs_wspecifier = "ark:| copy-matrix ark:- ark:-"
    with kaldi_io.open_or_fd(lprobs_wspecifier, "wb") as f:
        if chunk_width is None:  # normal dumping (i.e., no chunking)
            for sample in progress:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                gen_timer.start()
                lprobs, padding_mask = task.inference_step(
                    generator, models, sample)
                if log_prior is not None:
                    assert lprobs.size(-1) == log_prior.size(0)
                    lprobs = lprobs - log_prior
                # Number of non-padded positions per utterance, when a mask
                # is returned by the inference step.
                out_lengths = (~padding_mask).long().sum(
                    dim=1).cpu() if padding_mask is not None else None
                num_processed_frames = sample["ntokens"]
                gen_timer.stop(num_processed_frames)
                num_sentences += sample["nsentences"]

                if out_lengths is not None:
                    for i in range(sample["nsentences"]):
                        length = out_lengths[i]
                        kaldi_io.write_mat(f,
                                           lprobs[i, :length, :].cpu().numpy(),
                                           key=sample["utt_id"][i])
                else:
                    for i in range(sample["nsentences"]):
                        kaldi_io.write_mat(f,
                                           lprobs[i, :, :].cpu().numpy(),
                                           key=sample["utt_id"][i])
        else:  # dumping chunks within the same utterance from left to right
            for sample in progress:  # sample is actually a list of batches
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                utt_id = sample[0]["utt_id"]
                # Named `sample_ids` rather than `id` to avoid shadowing the
                # builtin.
                sample_ids = sample[0]["id"]
                whole_lprobs = None
                for i, chunk_sample in enumerate(sample):
                    if "net_input" not in chunk_sample:
                        continue

                    # Every chunk batch must describe the same utterances in
                    # the same order as the first one.
                    assert chunk_sample["utt_id"] == utt_id and (
                        chunk_sample["id"] == sample_ids).all()
                    gen_timer.start()
                    lprobs, _ = task.inference_step(generator, models,
                                                    chunk_sample)
                    if log_prior is not None:
                        assert lprobs.size(-1) == log_prior.size(0)
                        lprobs = lprobs - log_prior
                    # Accumulate chunk outputs on CPU along dimension 1.
                    if whole_lprobs is None:
                        whole_lprobs = lprobs.cpu()
                    else:
                        whole_lprobs = torch.cat((whole_lprobs, lprobs.cpu()),
                                                 1)
                    num_processed_frames = chunk_sample["ntokens"]
                    gen_timer.stop(num_processed_frames)

                    # Write once the last chunk of the utterances is done.
                    if i == len(sample) - 1:
                        num_sentences += len(utt_id)
                        for j in range(len(utt_id)):
                            truncated_length = models[0].output_lengths(
                                task.dataset(
                                    args.gen_subset).src_sizes[sample_ids[j]]
                            )  # length is after possible subsampling by the model
                            mat = whole_lprobs[j, :truncated_length, :]
                            kaldi_io.write_mat(f, mat.numpy(), key=utt_id[j])

    # Guard the throughput figures: when nothing was processed the timer
    # total is zero and the unguarded divisions would raise
    # ZeroDivisionError after an otherwise successful run.
    elapsed = gen_timer.sum
    if elapsed > 0:
        sentences_per_sec = num_sentences / elapsed
        frames_per_sec = 1. / gen_timer.avg
    else:
        sentences_per_sec = 0.0
        frames_per_sec = 0.0

    logger.info(
        "Dumped {} utterances ({} frames) in {:.1f}s ({:.2f} sentences/s, {:.2f} frames/s)"
        .format(num_sentences, gen_timer.n, elapsed,
                sentences_per_sec, frames_per_sec))