예제 #1
0
def int2sym(transcript_file, mapping_file):
    """Map integer-id transcripts back to their symbols.

    ``mapping_file`` holds one ``<symbol> <id>`` pair per line and
    ``transcript_file`` holds one ``<sample_id> <id>`` entry per line.
    Blank lines in either file are ignored.  A trailing ``-1`` on a sample
    id is stripped.  An entry that consists of a sample id only (empty
    transcript) is mapped to the ``<UNK>`` symbol.

    Args:
        transcript_file: path of the transcript file.
        mapping_file: path of the symbol table.

    Returns:
        dict mapping every sample id to a one-element list with its symbol.

    Raises:
        ValueError: if a mapping line does not consist of exactly two
            space-separated fields.
        KeyError: if a transcript id is not part of the mapping, or if
            ``<UNK>`` is needed but not defined in ``mapping_file``.
        AssertionError: if a sample id occurs more than once.
    """
    with open(mapping_file, "r") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        mapping = dict(line.strip().split(" ") for line in f if line.strip())

    inv_mapping = {_id: phn for phn, _id in mapping.items()}

    with open(transcript_file, "r") as f:
        transcripts = [
            line.strip().split(" ", 1) for line in f if line.strip()
        ]

    mapped_transcripts = {}
    for entry in transcripts:
        if len(entry) == 1:
            # Only the sample id is present: fall back to the id of <UNK>.
            # The id itself (not the split list) has to be used as key.
            sample_id = entry[0]
            transcript = mapping['<UNK>']
            logger.debug(
                f"transcript of {sample_id} was empty, using <UNK> instead")
        else:
            sample_id, transcript = entry

        if sample_id.endswith("-1"):
            sample_id = sample_id[:-2]
        assert sample_id not in mapped_transcripts
        mapped_transcripts[sample_id] = [inv_mapping[transcript]]

    return mapped_transcripts
예제 #2
0
def _make_frames_sequential(samples_list, main_feat, aligned_labels,
                            max_seq_len, left_context, right_context):
    # sequential data
    if not aligned_labels:
        # unaligned labels
        sample_splits, min_len = filter_by_seqlen(samples_list, max_seq_len,
                                                  left_context, right_context)
        logger.debug(
            f"Used samples {len(sample_splits)}/{len(samples_list)} " +
            f"for a max seq length of {max_seq_len} (min length was {min_len})"
        )

    elif not aligned_labels and not max_seq_len:
        # unaligned labels but no max_seq_len
        sample_splits = [
            (filename, left_context,
             len(sample_dict["features"][main_feat]) - right_context)
            for filename, sample_dict in samples_list
        ]
    else:
        # framewise sequential
        if max_seq_len:
            sample_splits = splits_by_seqlen(samples_list, max_seq_len,
                                             left_context, right_context)
        else:
            raise NotImplementedError("Framewise without max_seq_len not impl")

    max_len = 0
    min_len = sys.maxsize

    for sample_id, start_idx, end_idx in sample_splits:
        max_len = (end_idx - start_idx) \
            if (end_idx - start_idx) > max_len else max_len

        min_len = (end_idx - start_idx) \
            if (end_idx - start_idx) < min_len else min_len

    # sort sigs/labels: longest -> shortest
    sample_splits = sorted(sample_splits, key=lambda x: x[2] - x[1])

    return sample_splits, max_len, min_len
예제 #3
0
    def _check_labels_indexed_from(self, all_labels_loaded, label_name):
        """Re-base the labels of ``label_name`` in place.

        If the smallest label value is greater than zero, that minimum is
        subtracted from every label so the smallest value becomes zero.
        Afterwards ``self.state.label_index_from`` is added to every label
        when it is non-zero.

        Args:
            all_labels_loaded: dict ``label_name -> {sample: labels}``; the
                per-sample labels must offer ``.max()`` / ``.min()`` and
                arithmetic with scalars (presumably numpy arrays -- TODO
                confirm).  The dict is modified in place.
            label_name: key of the label set to process.

        Raises:
            ValueError: if the label set is empty (``max()`` of an empty list).
            AssertionError: if ``self.state.label_index_from`` is negative.
        """

        def _label_range():
            # Largest and smallest label value over all samples.
            labels = all_labels_loaded[label_name]
            largest = max([labels[l].max() for l in labels])
            smallest = min([labels[l].min() for l in labels])
            return largest, smallest

        max_label, min_label = _label_range()
        logger.debug(f"Max label: {max_label}")
        logger.debug(f"min label: {min_label}")

        if min_label > 0:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f"label {label_name} is appears to be indexed from {min_label} -> making it indexed from 0"
            )
            for l in all_labels_loaded[label_name]:
                all_labels_loaded[label_name][
                    l] = all_labels_loaded[label_name][l] - min_label

            max_label, min_label = _label_range()
            logger.debug(f"Max label new : {max_label}")
            logger.debug(f"min label new: {min_label}")

        if self.state.label_index_from != 0:
            assert self.state.label_index_from > 0
            all_labels_loaded[label_name] = {
                filename: all_labels_loaded[label_name][filename] +
                self.state.label_index_from
                for filename in all_labels_loaded[label_name]
            }
예제 #4
0
    def is_keyword_batch(self, input_features, sensitivity, tmp_out_dir=None):
        """Run the model on a batch of samples and decode the written outputs.

        Every feature array is normalised with the mean and standard
        deviation computed over all samples of the batch, passed through
        ``self.model`` one sample at a time, and the ``out_phn`` output is
        written with a ``KaldiOutputWriter``.  The written file is then handed
        to ``decode_ctc``.

        Args:
            input_features: mapping of sample name to a numpy feature array
                (it is transposed and converted with ``torch.from_numpy``;
                presumably shaped ``(frames, feat_dim)`` -- TODO confirm).
            sensitivity: not used by this method.
            tmp_out_dir: output folder; defaults to ``self.out_dir``.

        Returns:
            The value returned by ``decode_ctc``.

        Raises:
            ValueError: if ``input_features`` is empty.
            AssertionError: if ``out_phn`` is not a model output, or the
                output has an unexpected shape or contains NaN.
        """
        if tmp_out_dir is None:
            tmp_out_dir = self.out_dir

        # Debug switch: when enabled the first 20 samples are plotted.
        plot_phns = False

        # NOTE(review): hard-coded size; the assertion before write_mat
        # requires len(reducedIdx2phoneme) + 1 output columns -- presumably
        # both have to agree, confirm before changing the phoneme set.
        vocabulary_size = 42
        vocabulary = [
            chr(c) for c in list(range(65, 65 + 58)) +
            list(range(65 + 58 + 69, 65 + 58 + 69 + 500))
        ][:vocabulary_size]
        decoder = ctcdecode.CTCBeamDecoder(vocabulary,
                                           log_probs_input=True,
                                           beam_width=1)

        # Gather first and concatenate once; concatenating inside the loop
        # copies the growing array on every iteration (quadratic).
        # NOTE: if the batch does not fit into memory, a streaming mean/std
        # would be needed instead.
        all_feats = [feat for _, feat in tqdm(input_features.items())]
        if not all_feats:
            raise ValueError("input_features must not be empty")
        all_samples_concat = np.concatenate(all_feats)

        mean = torch.from_numpy(np.mean(
            all_samples_concat, axis=0)).to(dtype=torch.float32).unsqueeze(-1)
        std = torch.from_numpy(np.std(
            all_samples_concat, axis=0)).to(dtype=torch.float32).unsqueeze(-1)
        post_files = []

        plot_num = 0

        output_label = 'out_phn'
        assert output_label in self.model.out_names
        with KaldiOutputWriter(tmp_out_dir, "keyword", [output_label],
                               self.epoch) as writer:
            post_files.append(writer.post_file[output_label].name)
            for sample_name in tqdm(input_features,
                                    desc="computing acoustic features:",
                                    position=1):
                fbank = torch.from_numpy(
                    input_features[sample_name].T).unsqueeze(0)
                # Normalize over the whole batch instead of a single file
                # (the latter is what applying the kaldi cmvn would do).
                fbank = (fbank - mean) / std

                # Short inputs get zeros of the context size prepended.
                context = self.model.context_left + self.model.context_right
                if fbank.shape[2] < context + 100:
                    padd = torch.zeros((fbank.shape[0], fbank.shape[1],
                                        context),
                                       device=fbank.device,
                                       dtype=fbank.dtype)
                    fbank = torch.cat((padd, fbank), dim=2)
                input_feature = {"fbank": fbank}

                output = self.model(input_feature)
                assert output_label in output
                output = output[output_label]

                _logits = output.detach().permute(0, 2, 1)

                output = output.detach().squeeze(0).numpy().T

                # TODO try normalising the posteriors with the label priors
                # (different blank / noise scales).

                beam_result, _, _, out_seq_len = decoder.decode(_logits)
                beam_result = beam_result[0, 0, :out_seq_len[0, 0]]
                # Decoder index 0 is skipped by the "- 1" offset (presumably
                # the CTC blank -- TODO confirm).
                result_decoded = " ".join(
                    self.phoneme_dict.reducedIdx2phoneme[l.item() - 1]
                    for l in beam_result)

                if plot_num < 20 and plot_phns:
                    plot_alignment_spectrogram(sample_name,
                                               input_feature["fbank"],
                                               (np.exp(output).T /
                                                np.exp(output).sum(axis=1)).T,
                                               self.phoneme_dict,
                                               result_decoded=result_decoded)

                    plot_num += 1

                assert len(output.shape) == 2
                assert np.sum(np.isnan(output)) == 0, "NaN in output"
                assert output.shape[1] == len(
                    self.phoneme_dict.reducedIdx2phoneme) + 1
                writer.write_mat(output_label, output.squeeze(), sample_name)

        #### DECODING ####
        logger.debug("Decoding...")
        result = decode_ctc(**self.config['dataset']['dataset_definition']
                            ['decoding'],
                            words_path=self.words_path,
                            graph_path=self.graph_path,
                            out_folder=tmp_out_dir,
                            featstrings=post_files)

        # TODO filter result

        return result
예제 #5
0
def evaluate(model,
             metrics,
             device,
             out_folder,
             exp_name,
             max_label_length,
             epoch,
             dataset_type,
             data_cache_root,
             test_with,
             all_feats_dict,
             features_use,
             all_labs_dict,
             labels_use,
             phoneme_dict,
             decoding_info,
             lab_graph_dir=None,
             tensorboard_logger=None):
    """Evaluate ``model`` on the ``test_with`` data and decode its output.

    The model is run with a batch size of 1.  The outputs listed in
    ``model.out_names`` are written with a ``KaldiOutputWriter`` while the
    metrics are accumulated by ``metrics_accumulator`` workers that are fed
    through a queue.  Afterwards the first model output is decoded
    (``decode_ctc`` for ``out_phn``, ``decode_ce`` for ``out_cd``) and the
    result of ``best_wer`` is collected.

    Args:
        model: model to evaluate; provides ``context_left``,
            ``context_right``, ``batch_ordering`` and ``out_names``.
        metrics: iterable of metrics; used as keys of the result.
        device: device the inputs (and most targets) are moved to.
        out_folder: folder for the written outputs and decoding results.
        exp_name: experiment name, part of the dataset cache name.
        max_label_length: forwarded to ``get_dataset``.
        epoch: current epoch, used for output naming and logging.
        dataset_type, data_cache_root: forwarded to ``get_dataset``.
        test_with: name of the data set to evaluate on.
        all_feats_dict, features_use: feature definitions and the subset
            of them to use.
        all_labs_dict, labels_use: label definitions and the subset of
            them to use.
        phoneme_dict: forwarded to ``get_dataset``.
        decoding_info: decoding options; ``scoring_type`` is read here and
            the whole dict is forwarded to ``decode_ce``.
        lab_graph_dir: graph directory; defaults to the ``lab_graph`` entry
            of the label definition.
        tensorboard_logger: optional logger for scalars and the WER text.

    Returns:
        dict with the keys ``test_metrics`` (metric averaged over the
        number of batches) and ``decoding_results``.

    Raises:
        NotImplementedError: for non-sequential dataset types or a first
            model output other than ``out_cd`` / ``out_phn``.
    """
    model.eval()
    batch_size = 1
    max_seq_length = -1

    accumulated_test_metrics = {metric: 0 for metric in metrics}

    test_data = test_with
    dataset = get_dataset(
        dataset_type,
        data_cache_root,
        f"{test_data}_{exp_name}",
        {feat: all_feats_dict[feat]
         for feat in features_use},
        {lab: all_labs_dict[lab]
         for lab in labels_use},
        max_seq_length,
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=phoneme_dict,
        max_seq_len=max_seq_length,
        max_label_length=max_label_length)

    dataloader = KaldiDataLoader(dataset,
                                 batch_size,
                                 use_gpu=False,
                                 batch_ordering=model.batch_ordering)

    assert len(dataset) >= batch_size, \
        f"Length of test dataset {len(dataset)} too small " \
        + f"for batch_size of {batch_size}"

    with Pool(os.cpu_count()) as pool:
        multip_process = Manager()
        metrics_q = multip_process.Queue(maxsize=os.cpu_count())
        # One accumulator per worker; all of them consume the same queue.
        metric_futures = [
            pool.apply_async(metrics_accumulator, (metrics_q, metrics))
            for _ in range(os.cpu_count())
        ]
        with KaldiOutputWriter(out_folder, test_data, model.out_names,
                               epoch) as writer:
            with tqdm(disable=not logger.isEnabledFor(logging.INFO),
                      total=len(dataloader),
                      position=0) as pbar:
                pbar.set_description('E e:{}    '.format(epoch))
                for sample_names, inputs, targets in dataloader:
                    inputs = to_device(device, inputs)
                    if "lab_phn" not in targets:
                        targets = to_device(device, targets)

                    output = model(inputs)

                    output = detach_cpu(output)
                    targets = detach_cpu(targets)

                    #### Logging ####
                    metrics_q.put((output, targets))

                    pbar.set_description('E e:{} '.format(epoch))
                    pbar.update()
                    #### /Logging ####

                    # Report skipped outputs only once per batch.
                    warned_label = False
                    for output_label in output:
                        if output_label not in model.out_names:
                            if not warned_label:
                                logger.debug(
                                    "Skipping saving forward for decoding for key {}"
                                    .format(output_label))
                                warned_label = True
                            continue

                        # squeeze that batch
                        output[output_label] = output[output_label].squeeze(1)
                        out_save = output[output_label].data.cpu().numpy()

                        if len(out_save.shape) == 3 \
                                and out_save.shape[0] == 1:
                            out_save = out_save.squeeze(0)

                        if dataset.state.dataset_type != DatasetType.SEQUENTIAL_APPENDED_CONTEXT \
                                and dataset.state.dataset_type != DatasetType.SEQUENTIAL:
                            raise NotImplementedError(
                                "TODO rescaling with prior")

                        # TODO optionally normalise the posteriors with the
                        # label priors before writing them.

                        # shape == NC
                        assert len(out_save.shape) == 2
                        assert len(sample_names) == 1
                        writer.write_mat(output_label, out_save.squeeze(),
                                         sample_names[0])

            # One sentinel per worker ends the accumulators; then merge
            # their partial sums.
            for _ in metric_futures:
                metrics_q.put(None)
            for metric_future in metric_futures:
                for metric, metric_value in metric_future.get().items():
                    accumulated_test_metrics[metric] += metric_value

    test_metrics = {
        metric: accumulated_test_metrics[metric] / len(dataloader)
        for metric in accumulated_test_metrics
    }
    if tensorboard_logger is not None:
        tensorboard_logger.set_step(epoch, 'eval')
        for metric, metric_value in test_metrics.items():
            # test_metrics is already averaged; do not divide a second time.
            tensorboard_logger.add_scalar(metric, metric_value)

    #### DECODING ####
    out_lab = model.out_names[0]  # TODO query from model or sth

    logger.debug('Decoding {} output {}'.format(test_with, out_lab))

    if out_lab == 'out_cd':
        _label = 'lab_cd'
    elif out_lab == 'out_phn':
        _label = 'lab_phn'
    else:
        raise NotImplementedError(out_lab)

    lab_field = all_labs_dict[_label]

    out_folder = os.path.abspath(out_folder)
    out_dec_folder = '{}/decode_{}_{}'.format(out_folder, test_with, out_lab)

    # e.g. logits_test_clean_100_ep006_out_phn.ark
    files_dec_list = glob(
        f'{out_folder}/exp_files/logits_{test_with}_ep*_{out_lab}.ark')

    if lab_graph_dir is None:
        lab_graph_dir = os.path.abspath(lab_field['lab_graph'])
    if _label == 'lab_phn':
        decode_ctc(data=os.path.abspath(lab_field['lab_data_folder']),
                   graphdir=lab_graph_dir,
                   out_folder=out_dec_folder,
                   featstrings=files_dec_list)
    elif _label == 'lab_cd':
        decode_ce(**decoding_info,
                  alidir=os.path.abspath(lab_field['label_folder']),
                  data=os.path.abspath(lab_field['lab_data_folder']),
                  graphdir=lab_graph_dir,
                  out_folder=out_dec_folder,
                  featstrings=files_dec_list)
    else:
        raise ValueError(_label)

    decoding_results = best_wer(out_dec_folder, decoding_info['scoring_type'])
    logger.info(decoding_results)

    # The logger is optional (defaults to None), so guard this call as well.
    if tensorboard_logger is not None:
        tensorboard_logger.add_text("WER results", str(decoding_results))

    # TODO plotting curves

    return {'test_metrics': test_metrics, "decoding_results": decoding_results}
예제 #6
0
def _load_labels(label_dict, label_index_from, max_label_length, phoneme_dict):
    """Load, filter, map and re-base all label sets of ``label_dict``.

    For every label name the labels are loaded with ``load_labels``.  Then:

    1. if ``max_label_length`` is a positive number, only samples with
       fewer labels than that are kept;
    2. for ``lab_phn`` (and a given ``phoneme_dict``) the labels are mapped
       through ``phoneme_dict.idx2reducedIdx``;
    3. if the smallest label value is greater than zero, all labels are
       decremented by one;
    4. if ``label_index_from`` is non-zero, it is added to all labels.

    Args:
        label_dict: dict ``label_name -> definition`` where the definition
            provides ``label_folder`` and ``label_opts``.
        label_index_from: non-negative offset added to every label.
        max_label_length: upper bound (exclusive) for the number of labels
            per sample; ``None`` or a non-positive value disables it.
        phoneme_dict: object with an ``idx2reducedIdx`` mapping, or ``None``.

    Returns:
        dict ``label_name -> {sample_id: labels}``.

    Raises:
        AssertionError: if a ``lab_phn`` label exceeds the largest key of
            ``phoneme_dict.idx2reducedIdx`` or ``label_index_from`` is
            negative.
        ValueError: if a label set is empty after filtering.
    """
    all_labels_loaded = {}

    for label_name in label_dict:
        labels = load_labels(label_dict[label_name]['label_folder'],
                             label_dict[label_name]['label_opts'])

        if max_label_length is not None and max_label_length > 0:
            labels = {
                l: labels[l]
                for l in labels if len(labels[l]) < max_label_length
            }

        if label_name == "lab_phn" and phoneme_dict is not None:
            # Loop invariant: look the largest phoneme index up only once.
            max_phoneme_idx = max(phoneme_dict.idx2reducedIdx.keys())
            for sample_id in labels:
                assert max(labels[sample_id]) <= max_phoneme_idx, \
                    "Are you sure you have the righ phoneme dict?" + \
                    " Labels have higher indices than phonemes ( {} <!= {} )".format(
                        max(labels[sample_id]), max_phoneme_idx)

                # map labels according to phoneme dict; the masks are taken
                # from the unmodified labels so mapped values are never
                # mapped a second time
                tmp_labels = np.copy(labels[sample_id])
                for k, v in phoneme_dict.idx2reducedIdx.items():
                    tmp_labels[labels[sample_id] == k] = v

                labels[sample_id] = tmp_labels

        max_label = max([labels[l].max() for l in labels])
        min_label = min([labels[l].min() for l in labels])
        logger.debug(f"Max label: {max_label}")
        logger.debug(f"min label: {min_label}")

        if min_label > 0:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f"label {label_name} does not seem to be indexed from 0 -> making it indexed from 0"
            )
            # NOTE(review): the shift is always 1, so the labels only end up
            # indexed from 0 when the minimum is exactly 1 -- confirm this
            # is intended.
            for l in labels:
                labels[l] = labels[l] - 1

            max_label = max([labels[l].max() for l in labels])
            min_label = min([labels[l].min() for l in labels])
            logger.debug(f"Max label new : {max_label}")
            logger.debug(f"min label new: {min_label}")

        if label_index_from != 0:
            assert label_index_from > 0
            labels = {
                filename: labels[filename] + label_index_from
                for filename in labels
            }

        all_labels_loaded[label_name] = labels

    return all_labels_loaded
예제 #7
0
파일: viz_asr.py 프로젝트: pfriesch/PhnKWS
def valid_epoch_sync_metrics(epoch, model, loss_fun, metrics, config,
                             max_label_length, device, tensorboard_logger):
    """Run one validation pass and report the mean loss and metrics.

    The validation dataset is the one named by
    ``config['dataset']['data_use']['valid_with']``. For every batch the
    model output, loss and metrics are computed; on steps 1 and 60 the
    ``out_phn`` output is additionally decoded with a CTC beam decoder
    (``beam_width=1``) and plotted next to the reference labels.

    Args:
        epoch: Epoch number, used for the progress bar and as tensorboard step.
        model: Model to evaluate; switched to eval mode here.
        loss_fun: Callable ``loss_fun(output, targets)`` returning a mapping
            that contains ``"loss_final"``.
        metrics: Iterable of metric identifiers passed to ``eval_metrics``.
        config: Experiment configuration (nested dict).
        max_label_length: Forwarded to ``get_dataset``.
        device: Device the inputs (and, where applicable, targets) are moved to.
        tensorboard_logger: Logger receiving the averaged scalars.

    Returns:
        dict with ``'valid_loss'`` (mean loss per step) and ``'valid_metrics'``
        (mean value per metric).
    """
    model.eval()

    valid_loss = 0
    accumulated_valid_metrics = {metric: 0 for metric in metrics}

    valid_data = config['dataset']['data_use']['valid_with']
    _all_feats = config['dataset']['dataset_definition']['datasets'][
        valid_data]['features']
    _all_labs = config['dataset']['dataset_definition']['datasets'][
        valid_data]['labels']
    dataset = get_dataset(
        config['training']['dataset_type'],
        config['exp']['data_cache_root'],
        f"{valid_data}_{config['exp']['name']}",
        {feat: _all_feats[feat]
         for feat in config['dataset']['features_use']},
        {lab: _all_labs[lab]
         for lab in config['dataset']['labels_use']},
        config['training']['batching']['max_seq_length_valid'],
        model.context_left,
        model.context_right,
        normalize_features=True,
        phoneme_dict=config['dataset']['dataset_definition']['phoneme_dict'],
        max_seq_len=config['training']['batching']['max_seq_length_valid'],
        max_label_length=max_label_length)

    batch_size_valid = config['training']['batching']['batch_size_valid']
    if batch_size_valid != 1:
        # NOTE(review): the message claims the batch size is set to 1, but the
        # configured value is still passed to the data loader below - confirm
        # which behaviour is intended.
        logger.warning("setting valid batch size to 1 to avoid padding zeros")
    dataloader = KaldiDataLoader(
        dataset,
        batch_size_valid,
        config["exp"]["n_gpu"] > 0,
        batch_ordering=model.batch_ordering)

    assert len(dataset) >= config['training']['batching']['batch_size_valid'], \
        f"Length of valid dataset {len(dataset)} too small " \
        + f"for batch_size of {config['training']['batching']['batch_size_valid']}"

    n_steps_this_epoch = 0
    with tqdm(disable=not logger.isEnabledFor(logging.INFO),
              total=len(dataloader)) as pbar:
        pbar.set_description('V e:{} l: {} '.format(epoch, '-'))
        for batch_idx, (sample_name, inputs, targets) in enumerate(dataloader):
            n_steps_this_epoch += 1

            inputs = to_device(device, inputs)
            # Targets that carry "lab_phn" are left where they are.
            if "lab_phn" not in targets:
                targets = to_device(device, targets)

            output = model(inputs)
            loss = loss_fun(output, targets)

            output = detach_cpu(output)
            targets = detach_cpu(targets)
            loss = detach_cpu(loss)

            #### Logging ####
            valid_loss += loss["loss_final"].item()
            _valid_metrics = eval_metrics((output, targets), metrics)
            for metric, metric_value in _valid_metrics.items():
                accumulated_valid_metrics[metric] += metric_value

            pbar.set_description('V e:{} l: {:.4f} '.format(
                epoch, loss["loss_final"].item()))
            pbar.update()

            do_plotting = True
            # Plot only on two fixed steps. The flag now gates both of them;
            # previously operator precedence made it apply to step 1 only.
            if do_plotting and n_steps_this_epoch in (1, 60):
                out_phn = output['out_phn']
                fbank = inputs["fbank"].numpy()
                _phoneme_dict = dataset.state.phoneme_dict
                # One extra entry on top of the phoneme inventory: decoder
                # index 0 has no phoneme (presumably the CTC blank - TODO
                # confirm); see the "- 1" offsets below.
                vocabulary_size = len(
                    dataset.state.phoneme_dict.reducedIdx2phoneme) + 1
                # The decoder wants one label per class; arbitrary distinct
                # characters are used as placeholders.
                vocabulary = [
                    chr(c) for c in list(range(65, 65 + 58)) +
                    list(range(65 + 58 + 69, 65 + 58 + 69 + 500))
                ][:vocabulary_size]
                decoder = ctcdecode.CTCBeamDecoder(vocabulary,
                                                   log_probs_input=True,
                                                   beam_width=1)

                decoder_logits = out_phn.permute(0, 2, 1)
                # We expect batch x seq x label_size
                beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(
                    decoder_logits)

                # Targets are stored concatenated; cut them back into one
                # label sequence per sample.
                _targets = []
                offset = 0
                for length in targets['target_sequence_lengths']:
                    _targets.append(targets['lab_phn'][offset:offset + length])
                    offset += length

                for i in range(len(fbank)):
                    _beam_result = beam_result[i, 0, :out_seq_len[i, 0]]
                    result_decoded = " ".join([
                        _phoneme_dict.reducedIdx2phoneme[idx.item() - 1]
                        for idx in _beam_result
                    ])
                    logger.debug("RES: " + result_decoded)

                    label_decoded = " ".join([
                        _phoneme_dict.reducedIdx2phoneme[idx.item() - 1]
                        for idx in _targets[i]
                    ])
                    logger.debug("LAB: " + label_decoded)
                    text = sample_id_to_transcript(
                        sample_name[i],
                        "/mnt/data/datasets/LibriSpeech/dev-clean")
                    logger.debug("TXT: " + text)

                    # Exponentiate and normalise by the sum over axis 1 of
                    # this sample's output before plotting.
                    exp_out = np.exp(out_phn.numpy()[i])
                    plot_alignment_spectrogram_ctc(
                        sample_name[i],
                        fbank[i],
                        (exp_out.T / exp_out.sum(axis=1)).T,
                        _phoneme_dict,
                        label_decoded,
                        text,
                        result_decoded=result_decoded)

            #### /Logging ####

    # The per-step metrics were already accumulated inside the loop; adding
    # the totals to themselves here (as was done before) doubled every metric.
    mean_valid_loss = valid_loss / n_steps_this_epoch
    mean_valid_metrics = {
        metric: accumulated_valid_metrics[metric] / n_steps_this_epoch
        for metric in accumulated_valid_metrics
    }

    tensorboard_logger.set_step(epoch, 'valid')
    tensorboard_logger.add_scalar('valid_loss', mean_valid_loss)
    for metric, mean_value in mean_valid_metrics.items():
        tensorboard_logger.add_scalar(metric, mean_value)

    return {
        'valid_loss': mean_valid_loss,
        'valid_metrics': mean_valid_metrics
    }
예제 #8
0
    def is_keyword_batch(self, input_features, sensitivity, tmp_out_dir=None):
        """Compute acoustic-model outputs for a batch of samples and decode them.

        All samples are normalised with the mean and standard deviation taken
        over the whole batch, passed through ``self.model``, corrected with the
        log label prior from the config, written via ``KaldiOutputWriter`` and
        finally handed to ``decode``.

        Args:
            input_features: Mapping from sample name to a feature matrix.
            sensitivity: Not used by this method.
            tmp_out_dir: Output folder for the written files and the decoder;
                defaults to ``self.out_dir``.

        Returns:
            Whatever ``decode`` returns for the written output files.

        Raises:
            ValueError: If ``input_features`` is empty (raised by
                ``np.concatenate``).
        """
        if tmp_out_dir is None:
            tmp_out_dir = self.out_dir

        # Stack all samples once (instead of re-concatenating per sample, which
        # is quadratic) to get normalisation statistics over the whole batch.
        all_samples_concat = np.concatenate(list(input_features.values()))

        mean = torch.from_numpy(np.mean(all_samples_concat, axis=0)).to(dtype=torch.float32).unsqueeze(-1)
        std = torch.from_numpy(np.std(all_samples_concat, axis=0)).to(dtype=torch.float32).unsqueeze(-1)
        post_files = []

        # Debug switch: set to True to plot the exponentiated outputs per sample.
        plot_stuff = False

        # The log label prior is identical for every sample, so compute it once.
        counts = self.config['dataset']['dataset_definition']['data_info']['labels']['lab_cd']['lab_count']
        log_prior = np.log(counts / np.sum(counts)).reshape(-1, 1)

        with KaldiOutputWriter(tmp_out_dir, "keyword", self.model.out_names, self.epoch) as writer:
            output_label = 'out_cd'
            post_files.append(writer.post_file[output_label].name)
            for sample_name in tqdm(input_features, desc="computing acoustic features:", position=1):
                input_feature = {"fbank": self.preprocess_feat(input_features[sample_name])}
                # Normalize over whole chunk instead of only over a single file, which is done by applying the kaldi cmvn
                input_feature["fbank"] = ((input_feature["fbank"] - mean) / std)
                if self.model.batch_ordering == "TNCL":
                    input_feature["fbank"] = input_feature["fbank"].permute(2, 0, 1).unsqueeze(3)
                _output = self.model(input_feature)

                assert output_label in _output
                output = _output[output_label]

                if self.model.batch_ordering == "NCL":
                    output = output.detach().squeeze(0).numpy()

                elif self.model.batch_ordering == "TNCL":
                    output = output.detach().squeeze(1).numpy().T

                # Subtract the log prior from the model output.
                # NOTE(review): 3481 is a hard-coded size; when the output has
                # at least that many rows the first row is dropped before the
                # prior is applied (presumably an extra index-0 class - confirm).
                if len(output) >= 3481:  # TODO make based on index from
                    output = output[1:] - log_prior
                else:
                    output = output - log_prior

                output = output.transpose()

                if plot_stuff:
                    output_exp = np.exp(_output['out_cd'].detach().squeeze(1).numpy())
                    plot(sample_name, input_feature, output_exp, self.phoneme_dict.idx2phoneme)

                assert len(output.shape) == 2
                assert np.sum(np.isnan(output)) == 0, "NaN in output"
                writer.write_mat(output_label, output, sample_name)

        #### DECODING ####
        logger.debug("Decoding...")
        result = decode(**self.config['dataset']['dataset_definition']['decoding'],
                        alignment_model_path=self.alignment_model_path,
                        words_path=self.words_path,
                        graph_path=self.graph_path,
                        out_folder=tmp_out_dir,
                        featstrings=post_files)

        # TODO filter result

        return result