Example 1
0
def get_asr_dataset_from_json(
    data_path,
    split,
    dictionary,
    combine,
    upsample_primary,
    num_buckets=0,
    shuffle=True,
    lf_mmi=True,
    seed=1,
    specaugment_config=None,
    chunk_width=None,
    chunk_left_context=None,
    chunk_right_context=None,
    label_delay=0,
):
    """
    Build an ASR dataset from one or more ``<split>[k].json`` manifests.

    Each manifest maps an utterance id to its fields; see
    espresso/tools/asr_prep_json.py which packs such json from raw files.
    Json example:
    {
        "011c0202": {
            "feat": "data/train_si284_spe2e_hires/data/raw_mfcc_train_si284_spe2e_hires.1.ark:24847",
            "numerator_fst": "exp/chain/e2e_bichar_tree_tied1a/fst.1.ark:6704",
            "alignment": "exp/tri3/ali.ark:8769",
            "text": "THE HOTELi OPERATOR'S EMBASSY",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }

    Returns an AsrChainDataset when ``lf_mmi`` is True, otherwise an
    AsrXentDataset (cross-entropy training with optional chunking).
    """
    src_datasets = []
    tgt_datasets = []
    text_datasets = []

    # Consume split.json, split1.json, split2.json, ... until one is missing.
    for part in itertools.count():
        part_split = split + (str(part) if part > 0 else "")
        json_path = os.path.join(data_path, "{}.json".format(part_split))
        if not os.path.isfile(json_path):
            if part == 0:
                # The very first manifest is mandatory.
                raise FileNotFoundError(
                    "Dataset not found: {}".format(json_path))
            break

        with open(json_path, "rb") as fh:
            # OrderedDict keeps utterances in manifest order.
            entries = json.load(fh, object_pairs_hook=OrderedDict)

        utt_ids = []
        feats = []
        numerator_fsts = []
        alignments = []
        text = []
        utt2num_frames = []
        for utt_id, entry in entries.items():
            utt_ids.append(utt_id)
            feats.append(entry["feat"])
            # These three fields are optional per entry.
            for field, bucket in (
                ("numerator_fst", numerator_fsts),
                ("alignment", alignments),
                ("text", text),
            ):
                if field in entry:
                    bucket.append(entry[field])
            if "utt2num_frames" in entry:
                utt2num_frames.append(int(entry["utt2num_frames"]))

        # Frame counts are either absent entirely or present for every utt.
        assert len(utt2num_frames) == 0 or len(utt2num_frames) == len(utt_ids)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids,
                feats,
                utt2num_frames=utt2num_frames,
                seed=seed,
                # SpecAugment is a training-time augmentation only.
                specaugment_config=(
                    specaugment_config if split == "train" else None),
                ordered_prefetch=True,
            )
        )
        if lf_mmi:
            if numerator_fsts:
                assert len(utt_ids) == len(numerator_fsts)
                tgt_datasets.append(
                    NumeratorGraphDataset(utt_ids, numerator_fsts))
        else:  # cross-entropy
            if alignments:
                assert len(utt_ids) == len(alignments)
                tgt_datasets.append(
                    AliScpCachedDataset(
                        utt_ids,
                        alignments,
                        utt2num_frames=utt2num_frames,
                        ordered_prefetch=True,
                    )
                )

        if text:
            assert len(utt_ids) == len(text)
            text_datasets.append(
                AsrTextDataset(utt_ids, text, dictionary, append_eos=False))

        logger.info("{} {} examples".format(json_path,
                                            len(src_datasets[-1])))

        if not combine:
            break

    # Targets/text are either present for every manifest or for none.
    assert not tgt_datasets or len(tgt_datasets) == len(src_datasets)
    assert not text_datasets or len(text_datasets) == len(src_datasets)

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if tgt_datasets else None
        text_dataset = text_datasets[0] if text_datasets else None
    else:
        for other in src_datasets[1:]:
            assert feat_dim == other.feat_dim, \
                "feature dimension does not match across multiple json files"
        # The primary (first) manifest may be upsampled relative to the rest.
        sample_ratios = [upsample_primary] + [1] * (len(src_datasets) - 1)
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        tgt_dataset = (
            ConcatDataset(tgt_datasets, sample_ratios)
            if tgt_datasets else None)
        text_dataset = (
            ConcatDataset(text_datasets, sample_ratios)
            if text_datasets else None)

    tgt_dataset_sizes = None if tgt_dataset is None else tgt_dataset.sizes
    if lf_mmi:
        return AsrChainDataset(
            src_dataset,
            src_dataset.sizes,
            tgt_dataset,
            tgt_dataset_sizes,
            text=text_dataset,
            num_buckets=num_buckets,
            shuffle=shuffle,
        )
    return AsrXentDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        text=text_dataset,
        num_buckets=num_buckets,
        shuffle=shuffle,
        seed=seed,
        chunk_width=chunk_width,
        chunk_left_context=chunk_left_context,
        chunk_right_context=chunk_right_context,
        label_delay=label_delay,
        # Random chunking only during training; deterministic otherwise.
        random_chunking=(split == "train" and chunk_width is not None),
    )
Example 2
0
 def build_dataset_for_inference(self, src_tokens, src_lengths):
     """Wrap pre-tokenized inputs and their lengths in an AsrChainDataset for inference."""
     dataset = AsrChainDataset(src_tokens, src_lengths)
     return dataset
def get_asr_dataset_from_json(
    data_path,
    split,
    dictionary,
    combine,
    upsample_primary=1,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    lf_mmi=True,
    seed=1,
    global_cmvn_stats_path=None,
    specaugment_config=None,
    chunk_width=None,
    chunk_left_context=None,
    chunk_right_context=None,
    label_delay=0,
):
    """
    Parse data json and create dataset.
    See espresso/tools/asr_prep_json.py which pack json from raw files
    Json example:
    {
        "011c0202": {
            "feat": "data/train_si284_spe2e_hires/data/raw_mfcc_train_si284_spe2e_hires.1.ark:24847" or
            "wave": "/export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1" or
            "command": "sph2pipe -f wav /export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1 |",
            "numerator_fst": "exp/chain/e2e_bichar_tree_tied1a/fst.1.ark:6704",
            "alignment": "exp/tri3/ali.ark:8769",
            "text": "THE HOTELi OPERATOR'S EMBASSY",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }

    Returns an AsrChainDataset (LF-MMI) or AsrXentDataset (cross-entropy).
    """
    src_datasets = []
    tgt_datasets = []
    text_datasets = []

    # Consume split.json, split1.json, split2.json, ... until one is missing.
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                break
            else:
                # The first manifest is mandatory.
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            # OrderedDict keeps utterances in manifest order.
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, audios, numerator_fsts, alignments, text, utt2num_frames = [], [], [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            # Exactly one audio source per entry: precomputed features,
            # a wave file path, or a shell command producing audio.
            if "feat" in val:
                audio = val["feat"]
            elif "wave" in val:
                audio = val["wave"]
            elif "command" in val:
                audio = val["command"]
            else:
                raise KeyError(
                    f"'feat', 'wave' or 'command' should be present as a field for the entry {utt_id} in {data_json_path}"
                )
            audios.append(audio)
            if "numerator_fst" in val:
                numerator_fsts.append(val["numerator_fst"])
            if "alignment" in val:
                alignments.append(val["alignment"])
            if "text" in val:
                text.append(val["text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        # Frame counts are either absent entirely or present for every utt.
        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        # Decide whether this manifest carries precomputed features or raw
        # audio by inspecting the first entry (all entries are assumed to be
        # of the same kind).
        # Bug fix: the original `next(loaded_json.items())` raised TypeError
        # (dict views are not iterators), and membership against the
        # (utt_id, val) tuple would not have tested the entry's fields.
        first_entry = next(iter(loaded_json.values()), {})
        if "feat" in first_entry:
            extra_kwargs = {}
        else:
            # Raw audio: features are extracted on the fly.
            # NOTE(review): feat_dim=40 / mfcc appear to be fixed defaults
            # of the on-the-fly extractor — confirm against
            # AudioFeatCachedDataset before changing.
            extra_kwargs = {"feat_dim": 40, "feature_type": "mfcc"}
            if global_cmvn_stats_path is not None:
                feature_transforms_config = {
                    "transforms": ["global_cmvn"],
                    "global_cmvn": {
                        "stats_npz_path": global_cmvn_stats_path
                    }
                }
                extra_kwargs[
                    "feature_transforms_config"] = feature_transforms_config
        src_datasets.append(
            AudioFeatCachedDataset(utt_ids,
                                   audios,
                                   utt2num_frames=utt2num_frames,
                                   seed=seed,
                                   # SpecAugment only during training.
                                   specaugment_config=specaugment_config
                                   if split == "train" else None,
                                   ordered_prefetch=True,
                                   **extra_kwargs))
        if lf_mmi:
            if len(numerator_fsts) > 0:
                assert len(utt_ids) == len(numerator_fsts)
                tgt_datasets.append(
                    NumeratorGraphDataset(utt_ids, numerator_fsts))
        else:  # cross-entropy
            if len(alignments) > 0:
                assert len(utt_ids) == len(alignments)
                tgt_datasets.append(
                    AliScpCachedDataset(
                        utt_ids,
                        alignments,
                        utt2num_frames=utt2num_frames,
                        ordered_prefetch=True,
                    ))

        if len(text) > 0:
            assert len(utt_ids) == len(text)
            text_datasets.append(
                AsrTextDataset(utt_ids, text, dictionary, append_eos=False))

        logger.info("{} {} examples".format(data_json_path,
                                            len(src_datasets[-1])))

        if not combine:
            break

    # Targets/text are either present for every manifest or for none.
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
    assert len(src_datasets) == len(text_datasets) or len(text_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
        text_dataset = text_datasets[0] if len(text_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert (
                feat_dim == src_datasets[i].feat_dim
            ), "feature dimension does not match across multiple json files"
        # The primary (first) manifest may be upsampled relative to the rest.
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None
        if len(text_datasets) > 0:
            text_dataset = ConcatDataset(text_datasets, sample_ratios)
        else:
            text_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    if lf_mmi:
        return AsrChainDataset(
            src_dataset,
            src_dataset.sizes,
            tgt_dataset,
            tgt_dataset_sizes,
            text=text_dataset,
            num_buckets=num_buckets,
            shuffle=shuffle,
            pad_to_multiple=pad_to_multiple,
        )
    else:
        return AsrXentDataset(
            src_dataset,
            src_dataset.sizes,
            tgt_dataset,
            tgt_dataset_sizes,
            text=text_dataset,
            num_buckets=num_buckets,
            shuffle=shuffle,
            pad_to_multiple=pad_to_multiple,
            seed=seed,
            chunk_width=chunk_width,
            chunk_left_context=chunk_left_context,
            chunk_right_context=chunk_right_context,
            label_delay=label_delay,
            # Random chunking only during training; deterministic otherwise.
            random_chunking=(split == "train" and chunk_width is not None),
        )