Code example #1
 def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
     return AsrDataset(
         src_tokens,
         src_lengths,
         dictionary=self.target_dictionary,
         constraints=constraints,
     )
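
For orientation, here is a minimal usage sketch for this method. Everything outside the method itself (the `task` object, the dummy feature shapes, the 80-dimensional fbank assumption) is invented for illustration and is not defined by the snippet above.

import torch

# Dummy per-utterance feature matrices (num_frames x feat_dim) and their frame counts;
# the shapes here are assumptions made only for this sketch.
src_tokens = [torch.randn(300, 80), torch.randn(250, 80)]
src_lengths = torch.tensor([f.size(0) for f in src_tokens])

# `task` is assumed to be an already-configured ASR task exposing the method above.
dataset = task.build_dataset_for_inference(src_tokens, src_lengths, constraints=None)

# Batches can then be drawn exactly as in code example #2, i.e. with
# torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=...).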
Code example #2
File: test_asr_dataset.py, Project: vyraun/espresso
    def _asr_dataset_helper(
        self, all_in_memory=False, ordered_prefetch=False, has_utt2num_frames=False,
    ):
        if not all_in_memory:
            src_dataset = FeatScpCachedDataset(
                utt_ids=self.feats_utt_ids,
                rxfiles=self.rxfiles,
                utt2num_frames=self.utt2num_frames if has_utt2num_frames else None,
                ordered_prefetch=ordered_prefetch,
                cache_size=self.cache_size,
            )
        else:
            src_dataset = FeatScpInMemoryDataset(
                utt_ids=self.feats_utt_ids,
                rxfiles=self.rxfiles,
                utt2num_frames=self.utt2num_frames if has_utt2num_frames else None,
            )
        tgt_dataset = AsrTextDataset(
            utt_ids=self.text_utt_ids,
            token_text=self.token_text,
            dictionary=self.dictionary,
        )

        dataset = AsrDataset(
            src_dataset, src_dataset.sizes,
            tgt_dataset, tgt_dataset.sizes, self.dictionary,
            left_pad_source=False,
            left_pad_target=False,
            max_source_positions=1000,
            max_target_positions=200,
        )

        # assume one is a subset of the other
        expected_dataset_size = min(self.num_audios, self.num_transcripts)
        self.assertEqual(len(dataset.src), expected_dataset_size)
        self.assertEqual(len(dataset.tgt), expected_dataset_size)

        indices = list(range(expected_dataset_size))
        batch_sampler = []
        for i in range(0, expected_dataset_size, self.batch_size):
            batch_sampler.append(indices[i:i+self.batch_size])

        if not all_in_memory:
            dataset.prefetch(indices)

        dataloader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
        )

        for i, batch in enumerate(iter(dataloader)):
            bsz = batch["nsentences"]
            self.assertEqual(bsz, len(batch_sampler[i]))
            src_frames = batch["net_input"]["src_tokens"]
            src_lengths = batch["net_input"]["src_lengths"]
            tgt_tokens = self.dictionary.string(
                batch["target"], extra_symbols_to_ignore={self.dictionary.pad()}
            ).split("\n")
            tgt_tokens = [line.split(" ") for line in tgt_tokens]
            self.assertEqual(bsz, src_frames.size(0))
            self.assertEqual(bsz, src_lengths.numel())
            self.assertEqual(bsz, len(tgt_tokens))
            for j, utt_id in enumerate(batch["utt_id"]):
                self.assertTensorEqual(
                    torch.from_numpy(self.expected_feats[utt_id]).float(),
                    src_frames[j, :src_lengths[j], :]
                )
                self.assertEqual(
                    self.expected_text[utt_id],
                    tgt_tokens[j],
                )
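
The helper above is parameterized over the storage backend (`all_in_memory`), the prefetching behaviour (`ordered_prefetch`), and the presence of `utt2num_frames`. Concrete test methods would presumably exercise it with the different flag combinations, along the lines of the hypothetical wrappers below; the method names are illustrative and not taken from the original test file.

    # Hypothetical test methods; names and flag combinations are illustrative only.
    def test_asr_dataset_cached(self):
        self._asr_dataset_helper(all_in_memory=False, ordered_prefetch=False)

    def test_asr_dataset_cached_ordered_prefetch(self):
        self._asr_dataset_helper(all_in_memory=False, ordered_prefetch=True)

    def test_asr_dataset_all_in_memory(self):
        self._asr_dataset_helper(all_in_memory=True)

    def test_asr_dataset_with_utt2num_frames(self):
        self._asr_dataset_helper(all_in_memory=False, has_utt2num_frames=True)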
Code example #3
 def build_dataset_for_inference(self, src_tokens, src_lengths):
     return AsrDataset(src_tokens, src_lengths)
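
Compared with code example #1, this variant omits the target dictionary and the decoding `constraints` argument when constructing the AsrDataset.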
Code example #4
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary,
    max_source_positions,
    max_target_positions,
    seed=1,
    specaugment_config=None,
):
    """
    Parse the data JSON and create a dataset.
    See espresso/tools/asr_prep_json.py, which packs the JSON from raw files.
    JSON example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819",
            "token_text": "T H E <space> H O T E L",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path))

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, feats, token_text, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            feats.append(val["feat"])
            if "token_text" in val:
                token_text.append(val["token_text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        src_datasets.append(
            FeatScpCachedDataset(
                utt_ids,
                feats,
                utt2num_frames=utt2num_frames,
                seed=seed,
                specaugment_config=specaugment_config
                if split == "train" else None,
                ordered_prefetch=True,
            ))
        if len(token_text) > 0:
            assert len(utt_ids) == len(token_text)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, token_text, tgt_dict))

        logger.info("{} {} examples".format(data_json_path,
                                            len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert feat_dim == src_datasets[i].feat_dim, \
                "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
    )
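
A hedged sketch of a possible call site for this function, e.g. from a task's dataset-loading step. The paths and limits below are assumptions made up for illustration, and `tgt_dict` is assumed to have been loaded elsewhere (for instance by the task).

# Hypothetical call site; values are illustrative assumptions.
train_dataset = get_asr_dataset_from_json(
    data_path="data/json",            # directory containing train.json, train1.json, ...
    split="train",
    tgt_dict=tgt_dict,                # target-side dictionary, assumed loaded elsewhere
    combine=True,                     # also pick up train1.json, train2.json, ... if present
    upsample_primary=1,
    max_source_positions=9999,
    max_target_positions=999,
    seed=1,
    specaugment_config=None,          # a SpecAugment config string is only applied when split == "train"
)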
Code example #5
def get_asr_dataset_from_json(
    data_path,
    split,
    tgt_dict,
    combine,
    upsample_primary=1,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    seed=1,
    global_cmvn_stats_path=None,
    specaugment_config=None,
):
    """
    Parse the data JSON and create a dataset.
    See espresso/tools/asr_prep_json.py, which packs the JSON from raw files.
    JSON example:
    {
        "011c0202": {
            "feat": "fbank/raw_fbank_pitch_train_si284.1.ark:54819" or
            "wave": "/export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1" or
            "command": "sph2pipe -f wav /export/corpora5/LDC/LDC93S6B/11-1.1/wsj0/si_tr_s/011/011c0202.wv1 |",
            "text": "THE HOTEL",
            "utt2num_frames": "693",
        },
        "011c0203": {
            ...
        }
    }
    """
    src_datasets = []
    tgt_datasets = []
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")
        data_json_path = os.path.join(data_path, "{}.json".format(split_k))
        if not os.path.isfile(data_json_path):
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {}".format(data_json_path)
                )

        with open(data_json_path, "rb") as f:
            loaded_json = json.load(f, object_pairs_hook=OrderedDict)

        utt_ids, audios, texts, utt2num_frames = [], [], [], []
        for utt_id, val in loaded_json.items():
            utt_ids.append(utt_id)
            if "feat" in val:
                audio = val["feat"]
            elif "wave" in val:
                audio = val["wave"]
            elif "command" in val:
                audio = val["command"]
            else:
                raise KeyError(
                    f"'feat', 'wave' or 'command' should be present as a field for the entry {utt_id} in {data_json_path}"
                )
            audios.append(audio)
            if "text" in val:
                texts.append(val["text"])
            if "utt2num_frames" in val:
                utt2num_frames.append(int(val["utt2num_frames"]))

        assert len(utt2num_frames) == 0 or len(utt_ids) == len(utt2num_frames)
        if "feat" in next(iter(loaded_json.items())):
            extra_kwargs = {}
        else:
            extra_kwargs = {"feat_dim": 80, "feature_type": "fbank"}
            if global_cmvn_stats_path is not None:
                feature_transforms_config = {
                    "transforms": ["global_cmvn"],
                    "global_cmvn": {"stats_npz_path": global_cmvn_stats_path}
                }
                extra_kwargs["feature_transforms_config"] = feature_transforms_config
        src_datasets.append(AudioFeatDataset(
            utt_ids, audios, utt2num_frames=utt2num_frames, seed=seed,
            specaugment_config=specaugment_config if split == "train" else None,
            **extra_kwargs
        ))
        if len(texts) > 0:
            assert len(utt_ids) == len(texts)
            assert tgt_dict is not None
            tgt_datasets.append(AsrTextDataset(utt_ids, texts, tgt_dict))

        logger.info("{} {} examples".format(data_json_path, len(src_datasets[-1])))

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    feat_dim = src_datasets[0].feat_dim

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        for i in range(1, len(src_datasets)):
            assert (
                feat_dim == src_datasets[i].feat_dim
            ), "feature dimension does not match across multiple json files"
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return AsrDataset(
        src_dataset,
        src_dataset.sizes,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=False,
        left_pad_target=False,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
    )
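
And a similar hedged sketch of a call site for this newer variant, here assuming raw-waveform or command-based input so that the global CMVN branch is exercised. All paths and values are invented for illustration, and `tgt_dict` is again assumed to be loaded elsewhere.

# Hypothetical call site; values are illustrative assumptions.
train_dataset = get_asr_dataset_from_json(
    data_path="data/json",
    split="train",
    tgt_dict=tgt_dict,                        # target-side dictionary, assumed loaded elsewhere
    combine=True,
    upsample_primary=1,
    num_buckets=0,                            # >0 would bucket utterances by length before padding
    shuffle=True,
    pad_to_multiple=1,
    seed=1,
    global_cmvn_stats_path="data/gcmvn.npz",  # enables the "global_cmvn" feature transform for waveform input
    specaugment_config=None,                  # a config string is only applied when split == "train"
)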