Example #1
import numpy as np
import torch

import kaldi_io


# evaluation_batches() and decode_dataset() are helpers assumed to come from
# the surrounding module.
def decode(args, dataset, model, priors, device='cpu'):
    '''
        Produce lattices from the input utterances.
    '''
    # This is all of the Kaldi code we are calling. We simply pipe our
    # features to latgen-faster-mapped, which does all of the lattice
    # generation; an example expansion is shown after the format() call below.
    lat_output = '''ark:| copy-feats ark:- ark:- |\
    latgen-faster-mapped --min-active={} --max-active={} \
    --max-mem={} \
    --lattice-beam={} --beam={} \
    --acoustic-scale={} --allow-partial=true \
    --word-symbol-table={} \
    {} {} ark:- ark:- | lattice-scale --acoustic-scale={} ark:- ark:- |\
    gzip -c > {}/lat.{}.gz'''.format(args.min_active, args.max_active,
                                     args.max_mem, args.lattice_beam,
                                     args.beam, args.acoustic_scale,
                                     args.words_file, args.trans_mdl,
                                     args.hclg, args.post_decode_acwt,
                                     args.dumpdir, args.job)
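    # With illustrative argument values (invented for this comment, not taken
    # from the source), the formatted pipe expands to roughly:
    #   ark:| copy-feats ark:- ark:- | \
    #   latgen-faster-mapped --min-active=200 --max-active=7000 \
    #     --max-mem=50000000 --lattice-beam=8.0 --beam=15.0 \
    #     --acoustic-scale=0.1 --allow-partial=true \
    #     --word-symbol-table=words.txt final.mdl HCLG.fst ark:- ark:- | \
    #   lattice-scale --acoustic-scale=10.0 ark:- ark:- | gzip -c > dump/lat.1.gz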

    # Do the decoding (dumping senone posteriors)
    model.eval()
    with torch.no_grad():
        with kaldi_io.open_or_fd(lat_output, 'wb') as f:
            utt_mat = []
            prev_key = b''
            generator = evaluation_batches(dataset)
            # Each minibatch is guaranteed to have at most 1 utterance. We need
            # to append the output of subsequent minibatches corresponding to
            # the same utterance. These are stored in ``utt_mat'', which is
            # just a buffer that accumulates the posterior outputs of
            # minibatches corresponding to the same utterance. The posterior
            # state probabilities are normalized (subtraction in log space) by
            # the log priors in order to produce pseudo-likelihoods usable for
            # lattice generation with latgen-faster-mapped.
            for key, mat in decode_dataset(args,
                                           generator,
                                           model,
                                           device=device,
                                           output_idx=args.output_idx):
                if len(utt_mat) > 0 and key != prev_key:
                    kaldi_io.write_mat(f,
                                       np.concatenate(utt_mat,
                                                      axis=0)[:utt_length, :],
                                       key=prev_key.decode('utf-8'))
                    utt_mat = []
                utt_mat.append(mat - args.prior_scale * priors)
                prev_key = key
                utt_length = dataset.utt_lengths[key] // dataset.subsample

            # Flush utt_mat buffer at the end
            if len(utt_mat) > 0:
                kaldi_io.write_mat(f,
                                   np.concatenate(utt_mat,
                                                  axis=0)[:utt_length, :],
                                   key=prev_key.decode('utf-8'))
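
The key step above is the prior normalization: dividing posteriors by priors
in probability space becomes a subtraction in log space, turning the network's
p(state | frame) into a pseudo-likelihood proportional to p(frame | state). A
minimal sketch of that step in isolation, with illustrative values for the
posteriors, priors, and prior_scale (none of these come from the source):

import numpy as np

log_post = np.log(np.array([[0.7, 0.2, 0.1]]))   # one frame of senone posteriors
log_prior = np.log(np.array([0.5, 0.3, 0.2]))    # senone log priors
prior_scale = 1.0                                # illustrative scale
pseudo_llk = log_post - prior_scale * log_prior  # mirrors ``mat - args.prior_scale * priors''
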
Example #2
import numpy as np
import torch


# evaluation_batches(), decorrupt_dataset(), and decode_dataset() are helpers
# assumed to come from the surrounding module.
def decorrupt(args, dataset, model, objective, device='cpu'):
    '''
        Decorrupt the input utterances, dumping the output of each SGLD
        iteration per utterance as a NumPy matrix.
    '''
    model.eval()
    utt_mats = {}
    prev_key = b''
    stride = args.left_context + args.chunk_width + args.right_context
    delay = args.left_context
    generator = evaluation_batches(dataset, stride=stride, delay=delay)
    # Each minibatch is guaranteed to have at most 1 utterance. We need to
    # append the output of subsequent minibatches corresponding to the same
    # utterance. These are stored in ``utt_mats'', a dictionary keyed by SGLD
    # iteration, where each entry is a buffer that accumulates the outputs of
    # minibatches corresponding to the same utterance.
    for i, (key, sgld_iter, mat, targets) in enumerate(
            decorrupt_dataset(args, generator, model, objective,
                              device=device)):
        print(f"key: {key} sgld_iter: {sgld_iter}")
        print(f"targets: {targets}")
        if sgld_iter not in utt_mats:
            utt_mats[sgld_iter] = []

        if len(utt_mats[sgld_iter]) > 0 and key != prev_key:
            utt_length = dataset.utt_lengths[prev_key]
            for sgld_iter_ in utt_mats:
                np.save(
                    '{}/{}.{}'.format(args.dumpdir, prev_key.decode('utf-8'),
                                      str(sgld_iter_)),
                    np.concatenate(utt_mats[sgld_iter_],
                                   axis=0)[:utt_length, :],
                )
            utt_mats = {sgld_iter: []}

        utt_mats[sgld_iter].append(mat)
        prev_key = key

    # Flush the utt_mats buffers at the end
    if len(utt_mats) > 0:
        utt_length = dataset.utt_lengths[prev_key]
        # Skip empty buffers; indexing utt_mats[0] here would raise a KeyError
        # whenever iteration 0 is not among the keys.
        for sgld_iter in utt_mats:
            if len(utt_mats[sgld_iter]) == 0:
                continue
            np.save(
                '{}/{}.{}'.format(args.dumpdir, prev_key.decode('utf-8'),
                                  str(sgld_iter)),
                np.concatenate(utt_mats[sgld_iter],
                               axis=0)[:utt_length, :],
            )
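
Each call above writes NumPy files named {dumpdir}/{utterance}.{sgld_iter}.npy
(np.save appends the .npy suffix automatically). A minimal sketch of reading
one back; the dump directory, utterance id, and iteration index are
hypothetical:

import numpy as np

dumpdir = 'exp/decorrupt'       # hypothetical value of args.dumpdir
utt, sgld_iter = 'utt0001', 0   # hypothetical utterance id / SGLD iteration
mat = np.load('{}/{}.{}.npy'.format(dumpdir, utt, sgld_iter))
print(mat.shape)                # (num_frames, feature_dim)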

def forward(args, dataset, model, device='cpu'):
    '''
        Run the model over the input utterances and dump each utterance's
        embeddings as a NumPy matrix.
    '''
    model.eval()
    with torch.no_grad():
        utt_mat = []
        prev_key = b''
        generator = evaluation_batches(dataset)
        for key, mat in decode_dataset(args, generator, model, device=device):
            if len(utt_mat) > 0 and key != prev_key:
                np.save(
                    '{}/embeddings.{}'.format(args.dumpdir,
                                              prev_key.decode('utf-8')),
                    np.concatenate(utt_mat, axis=0)[:utt_length, :])
                utt_mat = []
            utt_mat.append(mat)
            prev_key = key
            utt_length = dataset.utt_lengths[key] // dataset.subsample
        if len(utt_mat) > 0:
            np.save(
                '{}/embeddings.{}'.format(args.dumpdir,
                                          prev_key.decode('utf-8')),
                np.concatenate(utt_mat, axis=0)[:utt_length, :],
            )