Example #1
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)
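These bin-script entry points are normally invoked through the standard module guard; a minimal sketch (the guard is an assumption, not shown in the excerpt):

if __name__ == "__main__":
    main()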
Example #2
def convert(jsonf, dic, refs, hyps, srcs, dic_src):

    # logging info
    logfmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with codecs.open(jsonf, 'r', encoding="utf-8") as f:
        j = json.load(f)

    # target dictionary
    logging.info("reading %s", dic)
    with codecs.open(dic, 'r', encoding="utf-8") as f:
        dictionary = f.readlines()
    char_list_tgt = [entry.split(' ')[0] for entry in dictionary]
    char_list_tgt.insert(0, '<blank>')
    char_list_tgt.append('<eos>')

    # source dictionary
    if dic_src:
        logging.info("reading %s", dic_src)
        with codecs.open(dic_src, 'r', encoding="utf-8") as f:
            dictionary = f.readlines()
        char_list_src = [entry.split(' ')[0] for entry in dictionary]
        char_list_src.insert(0, '<blank>')
        char_list_src.append('<eos>')

    if hyps:
        hyp_file = codecs.open(hyps[0], 'w', encoding="utf-8")
    ref_file = codecs.open(refs[0], 'w', encoding="utf-8")
    if srcs:
        src_file = codecs.open(srcs[0], 'w', encoding="utf-8")

    for x in j['utts']:
        # hyps
        if hyps:
            seq = [char_list_tgt[int(i)] for i in j['utts'][x]['output'][0]['rec_tokenid'].split()]
            hyp_file.write(" ".join(seq).replace('<eos>', '')),
            hyp_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') + "-" + x + ")\n")

        # ref
        seq = [char_list_tgt[int(i)] for i in j['utts'][x]['output'][0]['tokenid'].split()]
        ref_file.write(" ".join(seq).replace('<eos>', '')),
        ref_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') + "-" + x + ")\n")

        # src
        if srcs and 'tokenid_src' in j['utts'][x]['output'][0]:
            if dic_src:
                seq = [char_list_src[int(i)] for i in j['utts'][x]['output'][0]['tokenid_src'].split()]
            else:
                seq = [char_list_tgt[int(i)] for i in j['utts'][x]['output'][0]['tokenid_src'].split()]
            src_file.write(" ".join(seq).replace('<eos>', '')),
            src_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') + "-" + x + ")\n")

    if hyps:
        hyp_file.close()
    ref_file.close()
    if srcs:
        src_file.close()
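For reference, the write calls above emit trn-style scoring lines; a minimal sketch of the formatting in isolation (all values hypothetical):

seq = ["HELLO", "WORLD", "<eos>"]
utt2spk, utt_id = "spk1-a", "utt_0001"
line = (" ".join(seq).replace("<eos>", "")
        + " (" + utt2spk.replace("-", "_") + "-" + utt_id + ")")
# line == "HELLO WORLD  (spk1_a-utt_0001)"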
Example #3
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    with file_writer_helper(args.wspecifier,
                            filetype=args.filetype,
                            write_num_frames=args.write_num_frames,
                            compress=args.compress,
                            compression_method=args.compression_method,
                            pcm_format=args.format) as writer:
        for utt_id, (rate,
                     array) in kaldiio.ReadHelper(args.rspecifier,
                                                  args.segments):
            if args.filetype == 'mat':
                # Kaldi-matrix doesn't support integer
                array = array.astype(numpy.float32)

            if array.ndim == 1:
                # (Time) -> (Time, Channel)
                array = array[:, None]

            if args.normalize is not None and args.normalize != 1:
                array = array.astype(numpy.float32)
                array = array / (1 << (args.normalize - 1))

            if preprocessing is not None:
                orgtype = array.dtype
                out = preprocessing(array, uttid_list=utt_id)
                out = out.astype(orgtype)

                if args.keep_length:
                    if len(out) < len(array):
                        # The length can be changed by stft, for example.
                        out = numpy.pad(out, [(0, len(array) - len(out))] +
                                        [(0, 0) for _ in range(out.ndim - 1)],
                                        mode='constant')
                    elif len(out) > len(array):
                        out = out[:len(array)]

                array = out

            # shape = (Time, Channel)
            if args.filetype in ['sound.hdf5', 'sound']:
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt_id] = (rate, array)
            else:
                writer[utt_id] = array
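The array / (1 << (args.normalize - 1)) step divides by 2**(bits - 1); a self-contained sketch of why this maps integer PCM into roughly [-1, 1):

import numpy

pcm = numpy.array([-32768, 0, 32767], dtype=numpy.int16)  # 16-bit samples
scaled = pcm.astype(numpy.float32) / (1 << (16 - 1))       # divide by 2**15
# scaled == [-1.0, 0.0, 0.99996948...]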
Example #4
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    # Reading the whole matrix is unnecessary when no preprocessing is
    # applied, so let file_reader_helper return only the shape.
    # This makes sense only with filetype="hdf5".
    for utt, mat in file_reader_helper(args.rspecifier, args.filetype,
                                       return_shape=preprocessing is None):
        if preprocessing is not None:
            if is_scipy_wav_style(mat):
                # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
                rate, mat = mat
            mat = preprocessing(mat, uttid_list=utt)
            shape_str = ','.join(map(str, mat.shape))
        else:
            if len(mat) == 2 and isinstance(mat[1], tuple):
                # For sound files, the shape entry is Tuple[int, Tuple[int, ...]] = (rate, shape)
                rate, mat = mat
            shape_str = ','.join(map(str, mat))
        args.out.write('{} {}\n'.format(utt, shape_str))
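Either branch writes a comma-joined dimension string per utterance; a tiny sketch of the formatting (values hypothetical):

shape = (123, 80)                       # (frames, feature dim), hypothetical
shape_str = ','.join(map(str, shape))   # -> "123,80"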
Example #5
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with kaldiio.ReadHelper(args.rspecifier,
                            segments=args.segments) as reader, \
            file_writer_helper(args.wspecifier,
                               filetype=args.filetype,
                               write_num_frames=args.write_num_frames,
                               compress=args.compress,
                               compression_method=args.compression_method
                               ) as writer:
        for utt_id, (_, array) in reader:
            array = array.astype(numpy.float32)
            if args.normalize is not None and args.normalize != 1:
                array = array / (1 << (args.normalize - 1))
            spc = spectrogram(x=array,
                              n_fft=args.n_fft,
                              n_shift=args.n_shift,
                              win_length=args.win_length,
                              window=args.window)
            writer[utt_id] = spc
Example #6
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    logging.info(get_commandline_args())

    # check directory
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    for idx, (utt_id, lmspc) in enumerate(
            file_reader_helper(args.rspecifier, args.filetype), 1):
        if args.n_mels is not None:
            spc = logmelspc_to_linearspc(lmspc,
                                         fs=args.fs,
                                         n_mels=args.n_mels,
                                         n_fft=args.n_fft,
                                         fmin=args.fmin,
                                         fmax=args.fmax)
        else:
            spc = lmspc
        y = griffin_lim(spc,
                        n_fft=args.n_fft,
                        n_shift=args.n_shift,
                        win_length=args.win_length,
                        window=args.window,
                        n_iters=args.iters)
        logging.info("(%d) %s" % (idx, utt_id))
        write(args.outdir + "/%s.wav" % utt_id, args.fs,
              (y * np.iinfo(np.int16).max).astype(np.int16))
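griffin_lim above is an ESPnet helper; as a rough stand-in (an assumption about its behavior, not the project's implementation), librosa's Griffin-Lim can reconstruct a waveform from the linear magnitude spectrogram:

import librosa

# spc: linear magnitude spectrogram, shape (frames, 1 + n_fft // 2)
y = librosa.griffinlim(spc.T,              # librosa expects (freq, frames)
                       n_iter=args.iters,
                       hop_length=args.n_shift,
                       win_length=args.win_length,
                       window=args.window)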
Example #7
def convert(jsonf, dic, refs, hyps, num_spkrs=1):
    n_ref = len(refs)
    n_hyp = len(hyps)
    assert n_ref == n_hyp
    assert n_ref == num_spkrs

    # logging info
    logfmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with codecs.open(jsonf, 'r', encoding="utf-8") as f:
        j = json.load(f)

    logging.info("reading %s", dic)
    with codecs.open(dic, 'r', encoding="utf-8") as f:
        dictionary = f.readlines()
    char_list = [entry.split(' ')[0] for entry in dictionary]
    char_list.insert(0, '<blank>')
    char_list.append('<eos>')

    for ns in range(num_spkrs):
        hyp_file = codecs.open(hyps[ns], 'w', encoding="utf-8")
        ref_file = codecs.open(refs[ns], 'w', encoding="utf-8")

        for x in j['utts']:
            # hyps
            if num_spkrs == 1:
                seq = [
                    char_list[int(i)]
                    for i in j['utts'][x]['output'][0]['rec_tokenid'].split()
                ]
            else:
                seq = [
                    char_list[int(i)] for i in j['utts'][x]['output'][ns][0]
                    ['rec_tokenid'].split()
                ]
            hyp_file.write(" ".join(seq).replace('<eos>', '')),
            hyp_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') +
                           "-" + x + ")\n")

            # ref
            if num_spkrs == 1:
                seq = [
                    char_list[int(i)]
                    for i in j['utts'][x]['output'][0]['tokenid'].split()
                ]
            else:
                seq = [
                    char_list[int(i)]
                    for i in j['utts'][x]['output'][ns][0]['tokenid'].split()
                ]
            ref_file.write(" ".join(seq).replace('<eos>', '')),
            ref_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') +
                           "-" + x + ")\n")

        hyp_file.close()
        ref_file.close()
Example #8
def main(cmd=None):
    """Parse arguments and start the alignment in ctc_align(·)."""
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    ctc_align(**kwargs)
Example #9
def convert(jsonf, dic, refs, hyps, num_spkrs=1):
    n_ref = len(refs)
    n_hyp = len(hyps)
    assert n_ref == n_hyp
    assert n_ref == num_spkrs

    # logging info
    logfmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with codecs.open(jsonf, 'r', encoding="utf-8") as f:
        j = json.load(f)

    logging.info("reading %s", dic)
    with codecs.open(dic, 'r', encoding="utf-8") as f:
        dictionary = f.readlines()
    char_list = [entry.split(' ')[0] for entry in dictionary]
    char_list.insert(0, '<blank>')
    char_list.append('<eos>')

    for ns in range(num_spkrs):
        hyp_file = codecs.open(hyps[ns], 'w', encoding="utf-8")
        ref_file = codecs.open(refs[ns], 'w', encoding="utf-8")

        for x in j['utts']:
            # recognition hypothesis
            if num_spkrs == 1:
                seq = [
                    char_list[int(i)]
                    for i in j['utts'][x]['output'][0]['rec_tokenid'].split()
                ]
            else:
                seq = [
                    char_list[int(i)] for i in j['utts'][x]['output'][ns][0]
                    ['rec_tokenid'].split()
                ]
            # In the recognition hypothesis, the <eos> symbol is usually
            # attached at the end of the sentence; it is removed below.
            hyp_file.write(" ".join(seq).replace('<eos>', '')),
            hyp_file.write(" (" + j['utts'][x]['utt2spk'].replace('-', '_') +
                           "-" + x + ")\n")

            # reference
            if num_spkrs == 1:
                seq = j['utts'][x]['output'][0]['token']
            else:
                seq = j['utts'][x]['output'][ns][0]['token']
            # Unlike the recognition hypothesis, the reference is generated
            # directly from the token field without the dictionary, so that no
            # <unk> symbols are introduced into the reference and scoring
            # behaves normally. The detailed discussion can be found at
            # https://github.com/espnet/espnet/issues/993
            ref_file.write(seq + " (" +
                           j['utts'][x]['utt2spk'].replace('-', '_') + "-" +
                           x + ")\n")

        hyp_file.close()
        ref_file.close()
Example #10
def main():
    args = get_parser().parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if ":" in args.stats_rspecifier_or_rxfilename:
        is_rspecifier = True
        if args.stats_filetype == "npy":
            stats_filetype = "hdf5"
        else:
            stats_filetype = args.stats_filetype

        stats_dict = dict(
            file_reader_helper(args.stats_rspecifier_or_rxfilename,
                               stats_filetype))
    else:
        is_rspecifier = False
        if args.stats_filetype == "mat":
            stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
        else:
            stats = numpy.load(args.stats_rspecifier_or_rxfilename)
        stats_dict = {None: stats}

    cmvn = CMVN(
        stats=stats_dict,
        norm_means=args.norm_means,
        norm_vars=args.norm_vars,
        utt2spk=args.utt2spk,
        spk2utt=args.spk2utt,
        reverse=args.reverse,
    )

    with file_writer_helper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
            if is_scipy_wav_style(mat):
                # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
                rate, mat = mat
            mat = cmvn(mat, utt if is_rspecifier else None)
            writer[utt] = mat
Example #11
def main():
    parser = get_parser()
    args = parser.parse_args()

    # set logger
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if not os.path.exists(args.figdir):
        os.makedirs(args.figdir)

    with kaldiio.ReadHelper(args.rspecifier) as reader, \
            codecs.open(args.wspecifier, "w", encoding="utf-8") as f:
        for utt_id, (rate, array) in reader:
            assert rate == args.fs
            array = array.astype(numpy.float32)
            if args.normalize is not None and args.normalize != 1:
                array = array / (1 << (args.normalize - 1))
            array_trim, idx = librosa.effects.trim(
                y=array,
                top_db=args.threshold,
                frame_length=args.win_length,
                hop_length=args.shift_length
            )
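            # librosa.effects.trim returns (start, end) sample indices;
            # dividing by the sampling rate converts them to seconds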
            start, end = idx / args.fs

            # save figure
            plt.subplot(2, 1, 1)
            plt.plot(array)
            plt.title("Original")
            plt.subplot(2, 1, 2)
            plt.plot(array_trim)
            plt.title("Trim")
            plt.tight_layout()
            plt.savefig(args.figdir + "/" + utt_id + ".png")
            plt.close()

            # extend the segment by the minimum-silence margin
            start = max(0.0, start - args.min_silence)
            end = min(len(array) / args.fs, end + args.min_silence)

            # write to segments file
            segment = "%s %s %f %f\n" % (
                utt_id, utt_id, start, end
            )
            f.write(segment)
Example #12
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # Find the number of utterances
    with open(args.segments) as f:
        n_utt = sum(1 for _ in f)
    logging.info("%d utterances found to be processed." % n_utt)

    # Compute fbank features
    with kaldiio.ReadHelper(
            args.rspecifier,
            segments=args.segments) as reader, file_writer_helper(
                args.wspecifier,
                filetype=args.filetype,
                write_num_frames=args.write_num_frames,
                compress=args.compress,
                compression_method=args.compression_method,
            ) as writer:
        for i, struct in enumerate(reader, start=1):
            logging.info("processing %d/%d(%.2f%%)" %
                         (i, n_utt, 100 * i / n_utt))
            utt_id, (rate, array) = struct
            try:
                assert rate == args.fs
                array = array.astype(numpy.float32)
                if args.normalize is not None and args.normalize != 1:
                    array = array / (1 << (args.normalize - 1))

                lmspc = logmelspectrogram(
                    x=array,
                    fs=args.fs,
                    n_mels=args.n_mels,
                    n_fft=args.n_fft,
                    n_shift=args.n_shift,
                    win_length=args.win_length,
                    window=args.window,
                    fmin=args.fmin,
                    fmax=args.fmax,
                )
                writer[utt_id] = lmspc
            except Exception:
                logging.warning("failed to compute fbank for utt_id=`%s`",
                                utt_id)
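logmelspectrogram is an ESPnet transform helper; a rough librosa-based equivalent, as an assumption about the exact scaling and parameter mapping:

import librosa
import numpy

mel = librosa.feature.melspectrogram(y=array, sr=args.fs,
                                     n_fft=args.n_fft,
                                     hop_length=args.n_shift,
                                     win_length=args.win_length,
                                     window=args.window,
                                     n_mels=args.n_mels,
                                     fmin=args.fmin, fmax=args.fmax)
lmspc = numpy.log10(numpy.maximum(mel.T, 1e-10))  # (frames, n_mels)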
Example #13
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)

    d = ModelDownloader(".cache/espnet")
    o = d.download_and_unpack(args.mdl_file)
    kwargs = vars(args)
    kwargs.update(o)
    kwargs.update(
        {'data_path_and_name_and_type': [(args.wav_scp, 'speech', 'sound')]})
    del args.mdl_file
    del args.wav_scp

    kwargs.pop("config", None)
    inference(**kwargs)
Example #14
def convert(jsonf, refs, hyps, num_spkrs=1):
    n_ref = len(refs)
    n_hyp = len(hyps)
    assert n_ref == n_hyp
    assert n_ref == num_spkrs

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    logging.info("reading %s", jsonf)
    with codecs.open(jsonf, "r", encoding="utf-8") as f:
        j = json.load(f)

    for ns in range(num_spkrs):
        hyp_file = codecs.open(hyps[ns], "w", encoding="utf-8")
        ref_file = codecs.open(refs[ns], "w", encoding="utf-8")

        for x in j["utts"]:
            # recognition hypothesis
            if num_spkrs == 1:
                seq = j["utts"][x]["output"][0]["rec_text"].replace(
                    "<eos>", "")
            else:
                seq = j["utts"][x]["output"][ns][0]["rec_text"].replace(
                    "<eos>", "")
            # In the recognition hypothesis,
            # the <eos> symbol is usually attached at the end of the sentence;
            # it is removed below.
            hyp_file.write(seq)
            hyp_file.write(" (" + x.replace("-", "_") + ")\n")

            # reference
            if num_spkrs == 1:
                seq = j["utts"][x]["output"][0]["text"]
            else:
                seq = j["utts"][x]["output"][ns][0]["text"]
            # Unlike the recognition hypothesis,
            # the reference is generated directly from the text field without
            # the dictionary, so that no <unk> symbols are introduced into the
            # reference and scoring behaves normally.
            # The detailed discussion can be found at
            # https://github.com/espnet/espnet/issues/993
            ref_file.write(seq + " (" + x.replace("-", "_") + ")\n")

        hyp_file.close()
        ref_file.close()
Example #15
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with VideoReader(args.rspecifier) as reader, file_writer_helper(
            args.wspecifier,
            filetype=args.filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt_id, v_feature in reader:
            writer[utt_id] = v_feature
Example #16
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    with file_writer_helper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method,
    ) as writer:
        for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
            if is_scipy_wav_style(mat):
                # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
                rate, mat = mat

            if preprocessing is not None:
                mat = preprocessing(mat, uttid_list=utt)

            # shape = (Time, Channel)
            if args.out_filetype in ["sound.hdf5", "sound"]:
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt] = (rate, mat)
            else:
                writer[utt] = mat
Example #17
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    # make intersection set for utterance keys
    num_keys = 0
    js = {}
    for i, x in enumerate(args.jsons):
        with codecs.open(x, encoding="utf-8") as f:
            j = json.load(f)
        ks = j["utts"].keys()
        logging.debug(x + ": has " + str(len(ks)) + " utterances")

        num_keys += len(ks)
        if i > 0:
            for k in ks:
                js[k + "." + str(i)] = j["utts"][k]
        else:
            js = j["utts"]
        # js.update(j['utts'])

    # logging.info('new json has ' + str(len(js.keys())) + ' utterances')
    logging.info("new json has " + str(num_keys) + " utterances")

    # ensure "ensure_ascii=False", which is a bug
    jsonstring = json.dumps(
        {"utts": js},
        indent=4,
        sort_keys=True,
        ensure_ascii=False,
        separators=(",", ": "),
    )
    sys.stdout = codecs.getwriter("utf-8")(sys.stdout.buffer)
    print(jsonstring)
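The merge above suffixes utterance ids from the second and later jsons with ".1", ".2", ... so duplicate ids are never overwritten; schematically:

# json 0: {"utt_a": ...}    json 1: {"utt_a": ...}
# merged: {"utt_a": <from json 0>, "utt_a.1": <from json 1>}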
Example #18
def main():
    parser = argparse.ArgumentParser(
        description='Compute cepstral mean and '
        'variance normalization statistics. '
        'If wspecifier provided: per-utterance by default, '
        'or per-speaker if '
        'spk2utt option provided; if wxfilename: global',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--spk2utt',
                        type=str,
                        help='A text file of speaker to utterance-list map. '
                        '(Don\'t give rspecifier format, such as '
                        '"ark:utt2spk")')
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--in-filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for the rspecifier. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--out-filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'npy'],
                        help='Specify the file format for the wspecifier. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--preprocess-conf',
                        type=str,
                        default=None,
                        help='The configuration file for the pre-processing')
    parser.add_argument('rspecifier',
                        type=str,
                        help='Read specifier for feats. e.g. ark:some.ark')
    parser.add_argument('wspecifier_or_wxfilename',
                        type=str,
                        help='Write specifier. e.g. ark:some.ark')
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    is_wspecifier = ':' in args.wspecifier_or_wxfilename

    if is_wspecifier:
        if args.spk2utt is not None:
            logging.info('Performing as speaker CMVN mode')
            utt2spk_dict = {}
            with open(args.spk2utt) as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        utt2spk_dict[utt] = spk

            def utt2spk(x):
                return utt2spk_dict[x]
        else:
            logging.info('Performing as utterance CMVN mode')

            def utt2spk(x):
                return x

        if args.out_filetype == 'npy':
            logging.warning('--out-filetype npy is allowed only for '
                            'Global CMVN mode, changing to hdf5')
            args.out_filetype = 'hdf5'

    else:
        logging.info('Performing as global CMVN mode')
        if args.spk2utt is not None:
            logging.warning('spk2utt is not used for global CMVN mode')

        def utt2spk(x):
            return None

        if args.out_filetype == 'hdf5':
            logging.warning('--out-filetype hdf5 is not allowed for '
                            'Global CMVN mode, changing to npy')
            args.out_filetype = 'npy'

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    # Calculate stats for each speaker
    counts = {}
    sum_feats = {}
    square_sum_feats = {}

    idx = 0
    for idx, (utt, matrix) in enumerate(
            FileReaderWrapper(args.rspecifier, args.in_filetype), 1):
        if is_scipy_wav_style(matrix):
            # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
            rate, matrix = matrix
        if preprocessing is not None:
            matrix = preprocessing(matrix, uttid_list=utt)

        spk = utt2spk(utt)

        # Initialize on the first occurrence of the speaker
        if spk not in counts:
            counts[spk] = 0
            feat_shape = matrix.shape[1:]
            # Accumulate in double precision
            sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
            square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)

        counts[spk] += matrix.shape[0]
        sum_feats[spk] += matrix.sum(axis=0)
        square_sum_feats[spk] += (matrix**2).sum(axis=0)
    logging.info('Processed {} utterances'.format(idx))
    assert idx > 0, idx

    cmvn_stats = {}
    for spk in counts:
        feat_shape = sum_feats[spk].shape
        cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
        _cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
        _cmvn_stats[0, :-1] = sum_feats[spk]
        _cmvn_stats[1, :-1] = square_sum_feats[spk]

        _cmvn_stats[0, -1] = counts[spk]
        _cmvn_stats[1, -1] = 0.

        # You can get the mean and std as follows:
        # >>> N = _cmvn_stats[0, -1]
        # >>> mean = _cmvn_stats[0, :-1] / N
        # >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)

        cmvn_stats[spk] = _cmvn_stats

    # Per utterance or speaker CMVN
    if is_wspecifier:
        with FileWriterWrapper(args.wspecifier_or_wxfilename,
                               filetype=args.out_filetype) as writer:
            for spk, mat in cmvn_stats.items():
                writer[spk] = mat

    # Global CMVN
    else:
        matrix = cmvn_stats[None]
        if args.out_filetype == 'npy':
            np.save(args.wspecifier_or_wxfilename, matrix)
        elif args.out_filetype == 'mat':
            # Kaldi supports only matrix or vector
            kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
        else:
            raise RuntimeError('Not supporting: --out-filetype {}'.format(
                args.out_filetype))
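As the in-code comment above notes, mean and standard deviation are recoverable from the accumulated stats; a minimal sketch for the global (npy) case, with a hypothetical file name:

import numpy as np

stats = np.load("cmvn.npy")   # shape (2, dim + 1); the path is hypothetical
N = stats[0, -1]              # total frame count
mean = stats[0, :-1] / N
std = np.sqrt(stats[1, :-1] / N - mean ** 2)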
Example #19
def main():
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name "
        'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, "
        "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / "utt2num_samples"

    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    if args.segments is not None:
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        with SoundScpWriter(
                args.outdir,
                Path(args.outdir) / f"{args.name}.scp",
                format=args.audio_format,
        ) as writer, out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): To use sox?
                    wave = resampy.resample(wave.astype(np.float64),
                                            rate,
                                            args.fs,
                                            axis=0)
                    wave = wave.astype(np.int16)
                    rate = args.fs
                writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        wavdir = Path(args.outdir) / f"data_{args.name}"
        wavdir.mkdir(parents=True, exist_ok=True)
        out_wavscp = Path(args.outdir) / f"{args.name}.scp"

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
                "w") as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                            if wave.ndim == 2 and utt2ref_channels is not None:
                                wave = wave[:, utt2ref_channels(uttid)]

                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(wave.astype(np.float64),
                                                    rate,
                                                    args.fs,
                                                    axis=0)
                            wave = wave.astype(np.int16)
                            rate = args.fs

                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False

                    elif Path(wavpath).suffix == "." + args.audio_format and (
                            args.fs is None or args.fs == rate):
                        save_asis = True

                    else:
                        save_asis = False

                    if save_asis:
                        # The file is already in the requested audio format
                        # and sampling rate, needs no channel selection, and
                        # is not read via a unix pipe, so just reference the
                        # original file as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(wave.astype(np.float64),
                                                    rate,
                                                    args.fs,
                                                    axis=0)
                            wave = wave.astype(np.int16)
                            rate = args.fs

                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
Example #20
def main():
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info(get_commandline_args())

    # check directory
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load model config
    model_dir = os.path.dirname(args.model)
    train_args = torch.load(os.path.join(model_dir, "model.conf"))

    # load statistics
    scaler = StandardScaler()
    with h5py.File(os.path.join(model_dir, "stats.h5")) as f:
        scaler.mean_ = f["/melspc/mean"][()]
        scaler.scale_ = f["/melspc/scale"][()]
        # TODO(kan-bayashi): include following info as default
        coef = f["/mlsa/coef"][()]
        alpha = f["/mlsa/alpha"][()]

    # define MLSA filter for noise shaping
    mlsa_filter = TimeInvariantMLSAFilter(
        coef=coef,
        alpha=alpha,
        n_shift=args.n_shift,
    )

    # define model and load parameters
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = WaveNet(
        n_quantize=train_args.n_quantize,
        n_aux=train_args.n_aux,
        n_resch=train_args.n_resch,
        n_skipch=train_args.n_skipch,
        dilation_depth=train_args.dilation_depth,
        dilation_repeat=train_args.dilation_repeat,
        kernel_size=train_args.kernel_size,
        upsampling_factor=train_args.upsampling_factor,
    )
    model.load_state_dict(torch.load(args.model, map_location="cpu")["model"])
    model.eval()
    model.to(device)

    for idx, (utt_id, lmspc) in enumerate(
            file_reader_helper(args.rspecifier, args.filetype), 1):
        logging.info("(%d) %s" % (idx, utt_id))

        # perform preprocessing
        x = encode_mu_law(np.zeros(
            (1)), mu=train_args.n_quantize)  # quantize initial seed waveform
        h = scaler.transform(lmspc)  # normalize features

        # convert to tensor
        x = torch.tensor(x, dtype=torch.long, device=device)  # (1,)
        h = torch.tensor(h, dtype=torch.float, device=device)  # (T, n_aux)

        # get length of waveform
        n_samples = (h.shape[0] - 1) * args.n_shift + args.n_fft

        # generate
        start_time = time.time()
        with torch.no_grad():
            y = model.generate(x, h, n_samples, interval=100)
        logging.info("generation speed = %s (sec / sample)" %
                     ((time.time() - start_time) / (len(y) - 1)))
        y = decode_mu_law(y, mu=train_args.n_quantize)

        # apply mlsa filter for noise shaping
        y = mlsa_filter(y)

        # save as .wav file
        write(
            os.path.join(args.outdir, "%s.wav" % utt_id),
            args.fs,
            (y * np.iinfo(np.int16).max).astype(np.int16),
        )
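encode_mu_law/decode_mu_law are ESPnet helpers; a sketch of mu-law companding as it is commonly defined (an assumption about what they implement, with mu = n_quantize - 1):

import numpy as np

def mu_law_encode(x, mu=255):
    # Map x in [-1, 1] to integer codes in [0, mu].
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return np.floor((y + 1) / 2 * mu + 0.5).astype(np.int64)

def mu_law_decode(code, mu=255):
    # Invert mu_law_encode back to [-1, 1].
    y = 2 * code.astype(np.float64) / mu - 1
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu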
Example #21
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for the rspecifier. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--preprocess-conf',
                        type=str,
                        default=None,
                        help='The configuration file for the pre-processing')
    parser.add_argument('rspecifier',
                        type=str,
                        help='Read specifier for feats. e.g. ark:some.ark')
    parser.add_argument('out',
                        nargs='?',
                        type=argparse.FileType('w'),
                        default=sys.stdout,
                        help='The output filename. '
                        'If omitted, then output to sys.stdout')

    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    # Reading the whole matrix is unnecessary when no preprocessing is
    # applied, so let FileReaderWrapper return only the shape.
    # This makes sense only with filetype="hdf5".
    for utt, mat in FileReaderWrapper(args.rspecifier,
                                      args.filetype,
                                      return_shape=preprocessing is None):
        if preprocessing is not None:
            if is_scipy_wav_style(mat):
                # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
                rate, mat = mat
            mat = preprocessing(mat, uttid_list=utt)
            shape_str = ','.join(map(str, mat.shape))
        else:
            if len(mat) == 2 and isinstance(mat[1], tuple):
                # For sound files, the shape entry is Tuple[int, Tuple[int, ...]] = (rate, shape)
                rate, mat = mat
            shape_str = ','.join(map(str, mat))
        args.out.write('{} {}\n'.format(utt, shape_str))
Example #22
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--in-filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for the rspecifier. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--out-filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5'],
                        help='Specify the file format for the wspecifier. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--write-num-frames',
                        type=str,
                        help='Specify wspecifer for utt2num_frames')
    parser.add_argument('--compress',
                        type=strtobool,
                        default=False,
                        help='Save in compressed format')
    parser.add_argument(
        '--compression-method',
        type=int,
        default=2,
        help='Specify the method(if mat) or gzip-level(if hdf5)')
    parser.add_argument('--preprocess-conf',
                        type=str,
                        default=None,
                        help='The configuration file for the pre-processing')
    parser.add_argument('rspecifier',
                        type=str,
                        help='Read specifier for feats. e.g. ark:some.ark')
    parser.add_argument('wspecifier',
                        type=str,
                        help='Write specifier. e.g. ark:some.ark')
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    with FileWriterWrapper(
            args.wspecifier,
            filetype=args.out_filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method) as writer:
        for utt, mat in FileReaderWrapper(args.rspecifier, args.in_filetype):
            if is_scipy_wav_style(mat):
                # Sound files are loaded as Tuple[int, ndarray] = (rate, data)
                rate, mat = mat
            if preprocessing is not None:
                mat = preprocessing(mat, uttid_list=utt)
            writer[utt] = mat
Example #23
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--write-num-frames',
                        type=str,
                        help='Specify wspecifer for utt2num_frames')
    parser.add_argument('--filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for output. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--compress',
                        type=strtobool,
                        default=False,
                        help='Save in compressed format')
    parser.add_argument(
        '--compression-method',
        type=int,
        default=2,
        help='Specify the method(if mat) or gzip-level(if hdf5)')
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--normalize',
                        choices=[1, 16, 24, 32],
                        type=int,
                        default=None,
                        help='Give the bit depth of the PCM, '
                        'then normalizes data to scale in [-1,1]')
    parser.add_argument('rspecifier', type=str, nargs='+', help='WAV scp file')
    parser.add_argument('--segments',
                        type=str,
                        help='segments-file format: each line is '
                        '"<segment-id> <recording-id> <start-time> <end-time>" '
                        'e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5')
    parser.add_argument('wspecifier', type=str, help='Write specifier')
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    with FileWriterWrapper(
            args.wspecifier,
            filetype=args.filetype,
            write_num_frames=args.write_num_frames,
            compress=args.compress,
            compression_method=args.compression_method) as writer:
        for utt_id, (rate, array) in wav_generator(args.rspecifier,
                                                   args.segments):
            if args.filetype == 'mat':
                # Kaldi-matrix doesn't support integer
                array = array.astype(numpy.float32)

            if args.normalize is not None and args.normalize != 1:
                array = array.astype(numpy.float32)
                array = array / (1 << (args.normalize - 1))

            # shape = (Time, Channel)
            if args.filetype == 'sound.hdf5':
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt_id] = (rate, array)
            else:
                writer[utt_id] = array
Example #24
def main(cmd=None):
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = get_parser()
    args = parser.parse_args(cmd)
    args.cmd = shlex.split(args.cmd)

    if args.host is None and shutil.which(args.cmd[0]) is None:
        raise RuntimeError(
            f"The first args of --cmd should be a script path. e.g. utils/run.pl: "
            f"{args.cmd[0]}"
        )

    # Specify init_method:
    #   See: https://pytorch.org/docs/stable/distributed.html#initialization
    if args.host is None and args.num_nodes <= 1:
        # Automatically set init_method if num_nodes == 1
        init_method = None
    else:
        if args.master_port is None:
            # Try "shared-file system initialization" if master_port is not specified
            # Give random name to avoid reusing previous file
            init_file = args.init_file_prefix + str(uuid.uuid4())
            init_file = Path(init_file).absolute()
            Path(init_file).parent.mkdir(exist_ok=True, parents=True)
            init_method = ["--dist_init_method", f"file://{init_file}"]
        else:
            init_method = ["--dist_master_port", str(args.master_port)]

            # This can be omitted in slurm mode
            if args.master_addr is not None:
                init_method += ["--dist_master_addr", args.master_addr]
            elif args.host is not None:
                init_method += [
                    "--dist_master_addr",
                    args.host.split(",")[0].split(":")[0],
                ]

    # Log-rotation
    for i in range(args.max_num_log_files - 1, -1, -1):
        if i == 0:
            p = Path(args.log)
            pn = p.parent / (p.stem + f".1" + p.suffix)
        else:
            _p = Path(args.log)
            p = _p.parent / (_p.stem + f".{i}" + _p.suffix)
            pn = _p.parent / (_p.stem + f".{i + 1}" + _p.suffix)

        if p.exists():
            if i == args.max_num_log_files - 1:
                p.unlink()
            else:
                shutil.move(p, pn)

    processes = []
    # Submit command via SSH
    if args.host is not None:
        hosts = []
        ids_list = []
        # e.g. args.host = "host1:0:2,host2:0:1"
        for host in args.host.split(","):
            # e.g host = "host1:0:2"
            sps = host.split(":")
            host = sps[0]
            if len(sps) > 1:
                ids = [int(x) for x in sps[1:]]
            else:
                ids = list(range(args.ngpu))
            hosts.append(host)
            ids_list.append(ids)

        world_size = sum(max(len(x), 1) for x in ids_list)
        logging.info(f"{len(hosts)}nodes with world_size={world_size} via SSH")

        if args.envfile is not None:
            env = f"source {args.envfile}"
        else:
            env = ""

        if args.log != "-":
            Path(args.log).parent.mkdir(parents=True, exist_ok=True)
            f = Path(args.log).open("w", encoding="utf-8")
        else:
            # Output to stdout/stderr
            f = None

        rank = 0
        for host, ids in zip(hosts, ids_list):
            ngpu = 1 if len(ids) > 0 else 0
            ids = ids if len(ids) > 0 else ["none"]

            for local_rank in ids:
                cmd = (
                    args.args
                    + [
                        "--ngpu",
                        str(ngpu),
                        "--multiprocessing_distributed",
                        "false",
                        "--local_rank",
                        str(local_rank),
                        "--dist_rank",
                        str(rank),
                        "--dist_world_size",
                        str(world_size),
                    ]
                    + init_method
                )
                if ngpu == 0:
                    # Gloo supports both GPU and CPU mode.
                    #   See: https://pytorch.org/docs/stable/distributed.html
                    cmd += ["--dist_backend", "gloo"]

                heredoc = f"""<< EOF
set -euo pipefail
cd {os.getcwd()}
{env}
{" ".join([c if len(c) != 0 else "''" for c in cmd])}
EOF
"""

                # FIXME(kamo): The remote process stays alive even if this
                #  program is stopped, because we don't pass -t here
                #  (no pty is allocated), so it is not killed when the SSH
                #  connection is closed.
                process = subprocess.Popen(
                    ["ssh", host, "bash", heredoc], stdout=f, stderr=f,
                )

                processes.append(process)

                rank += 1

    # If Single node
    elif args.num_nodes <= 1:
        if args.ngpu > 1:
            if args.multiprocessing_distributed:
                # NOTE:
                #   If multiprocessing_distributed=true,
                # -> Distributed mode, which is multi-process and Multi-GPUs.
                #    and TCP initialization is used in the single-node case:
                #      e.g. init_method="tcp://localhost:20000"
                logging.info(f"single-node with {args.ngpu}gpu on distributed mode")
            else:
                # NOTE:
                #   If multiprocessing_distributed=false
                # -> "DataParallel" mode, which is single-process
                #    and Multi-GPUs with threading.
                # See:
                # https://discuss.pytorch.org/t/why-torch-nn-parallel-distributeddataparallel-runs-faster-than-torch-nn-dataparallel-on-single-machine-with-multi-gpu/32977/2
                logging.info(f"single-node with {args.ngpu}gpu using DataParallel")

        # Simply use cmd as is
        cmd = (
            args.cmd
            # arguments for ${cmd}
            + ["--gpu", str(args.ngpu), args.log]
            # arguments for *_train.py
            + args.args
            + [
                "--ngpu",
                str(args.ngpu),
                "--multiprocessing_distributed",
                str(args.multiprocessing_distributed),
            ]
        )
        process = subprocess.Popen(cmd)
        processes.append(process)

    elif Path(args.cmd[0]).name == "run.pl":
        raise RuntimeError("run.pl doesn't support submitting to the other nodes.")

    elif Path(args.cmd[0]).name == "ssh.pl":
        raise RuntimeError("Use --host option instead of ssh.pl")

    # If Slurm
    elif Path(args.cmd[0]).name == "slurm.pl":
        logging.info(f"{args.num_nodes}nodes and {args.ngpu}gpu-per-node using srun")
        cmd = (
            args.cmd
            # arguments for ${cmd}
            + [
                "--gpu",
                str(args.ngpu),
                "--num_threads",
                str(max(args.ngpu, 1)),
                "--num_nodes",
                str(args.num_nodes),
                args.log,
                "srun",
                # Inherit all environment variables from the parent process
                "--export=ALL",
            ]
            # arguments for *_train.py
            + args.args
            + [
                "--ngpu",
                str(args.ngpu),
                "--multiprocessing_distributed",
                "true",
                "--dist_launcher",
                "slurm",
            ]
            + init_method
        )
        if args.ngpu == 0:
            # Gloo supports both GPU and CPU mode.
            #   See: https://pytorch.org/docs/stable/distributed.html
            cmd += ["--dist_backend", "gloo"]
        process = subprocess.Popen(cmd)
        processes.append(process)

    else:
        # This pattern also works with Slurm.

        logging.info(f"{args.num_nodes}nodes and {args.ngpu}gpu-per-node using mpirun")
        cmd = (
            args.cmd
            # arguments for ${cmd}
            + [
                "--gpu",
                str(args.ngpu),
                "--num_threads",
                str(max(args.ngpu, 1)),
                # Check the scheduler setting, i.e. conf/queue.conf,
                # so that --num_nodes launches one process per node
                "--num_nodes",
                str(args.num_nodes),
                args.log,
                "mpirun",
                # -np option can be omitted with Torque/PBS
                "-np",
                str(args.num_nodes),
            ]
            # arguments for *_train.py
            + args.args
            + [
                "--ngpu",
                str(args.ngpu),
                "--multiprocessing_distributed",
                "true",
                "--dist_launcher",
                "mpi",
            ]
            + init_method
        )
        if args.ngpu == 0:
            # Gloo supports both GPU and CPU mode.
            #   See: https://pytorch.org/docs/stable/distributed.html
            cmd += ["--dist_backend", "gloo"]
        process = subprocess.Popen(cmd)
        processes.append(process)

    logging.info(f"log file: {args.log}")

    failed = False
    while any(p.returncode is None for p in processes):
        for process in processes:
            # If any process has failed, try to kill the other processes too
            if failed and process.returncode is None:
                process.kill()
                # Reap the killed process so its returncode gets set
                process.wait()
            else:
                try:
                    process.wait(0.5)
                except subprocess.TimeoutExpired:
                    pass

                if process.returncode is not None and process.returncode != 0:
                    failed = True

    for process in processes:
        if process.returncode != 0:
            print(
                subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd),
                file=sys.stderr,
            )
            p = Path(args.log)
            if p.exists():
                with p.open() as f:
                    lines = list(f)
                raise RuntimeError(
                    f"\n################### The last 1000 lines of {args.log} "
                    f"###################\n" + "".join(lines[-1000:])
                )
            else:
                raise RuntimeError(f"The log file doesn't exist: {args.log}")
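The monitoring loop above polls each child and, once one fails, tears the rest down. Below is a minimal self-contained sketch of that fail-fast pattern; the command list is a hypothetical stand-in for the per-node jobs, not part of the launcher above.

import subprocess
import sys

cmds = [["sleep", "2"], ["false"], ["sleep", "10"]]  # hypothetical jobs
processes = [subprocess.Popen(c) for c in cmds]

failed = False
while any(p.returncode is None for p in processes):
    for p in processes:
        if failed and p.returncode is None:
            # One job failed: take the still-running jobs down too
            p.kill()
            p.wait()  # reap so that returncode gets set
        else:
            try:
                p.wait(0.5)
            except subprocess.TimeoutExpired:
                continue
            if p.returncode is not None and p.returncode != 0:
                failed = True

sys.exit(1 if failed else 0)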
Example #25
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    split_scps(**kwargs)
Example #26
def main():
    args = get_parser().parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    is_wspecifier = ":" in args.wspecifier_or_wxfilename

    if is_wspecifier:
        if args.spk2utt is not None:
            logging.info("Performing as speaker CMVN mode")
            utt2spk_dict = {}
            with open(args.spk2utt) as f:
                for line in f:
                    spk, utts = line.rstrip().split(None, 1)
                    for utt in utts.split():
                        utt2spk_dict[utt] = spk

            def utt2spk(x):
                return utt2spk_dict[x]

        else:
            logging.info("Performing as utterance CMVN mode")

            def utt2spk(x):
                return x

        if args.out_filetype == "npy":
            logging.warning("--out-filetype npy is allowed only for "
                            "Global CMVN mode, changing to hdf5")
            args.out_filetype = "hdf5"

    else:
        logging.info("Performing as global CMVN mode")
        if args.spk2utt is not None:
            logging.warning("spk2utt is not used for global CMVN mode")

        def utt2spk(x):
            return None

        if args.out_filetype == "hdf5":
            logging.warning("--out-filetype hdf5 is not allowed for "
                            "Global CMVN mode, changing to npy")
            args.out_filetype = "npy"

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info("Apply preprocessing: {}".format(preprocessing))
    else:
        preprocessing = None

    # Calculate stats for each speaker
    counts = {}
    sum_feats = {}
    square_sum_feats = {}

    idx = 0
    for idx, (utt, matrix) in enumerate(
            file_reader_helper(args.rspecifier, args.in_filetype), 1):
        if is_scipy_wav_style(matrix):
            # If the data comes from a sound file, it is read as Tuple[int, ndarray]
            rate, matrix = matrix
        if preprocessing is not None:
            matrix = preprocessing(matrix, uttid_list=utt)

        spk = utt2spk(utt)

        # Initialize on the first occurrence of the speaker
        if spk not in counts:
            counts[spk] = 0
            feat_shape = matrix.shape[1:]
            # Accumulate in double precision
            sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
            square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)

        counts[spk] += matrix.shape[0]
        sum_feats[spk] += matrix.sum(axis=0)
        square_sum_feats[spk] += (matrix**2).sum(axis=0)
    logging.info("Processed {} utterances".format(idx))
    assert idx > 0, idx

    cmvn_stats = {}
    for spk in counts:
        feat_shape = sum_feats[spk].shape
        cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
        _cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
        _cmvn_stats[0, :-1] = sum_feats[spk]
        _cmvn_stats[1, :-1] = square_sum_feats[spk]

        _cmvn_stats[0, -1] = counts[spk]
        _cmvn_stats[1, -1] = 0.0

        # You can get the mean and std as follows:
        # >>> N = _cmvn_stats[0, -1]
        # >>> mean = _cmvn_stats[0, :-1] / N
        # >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)

        cmvn_stats[spk] = _cmvn_stats

    # Per utterance or speaker CMVN
    if is_wspecifier:
        with file_writer_helper(args.wspecifier_or_wxfilename,
                                filetype=args.out_filetype) as writer:
            for spk, mat in cmvn_stats.items():
                writer[spk] = mat

    # Global CMVN
    else:
        matrix = cmvn_stats[None]
        if args.out_filetype == "npy":
            np.save(args.wspecifier_or_wxfilename, matrix)
        elif args.out_filetype == "mat":
            # Kaldi supports only matrix or vector
            kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
        else:
            raise RuntimeError("Not supporting: --out-filetype {}".format(
                args.out_filetype))
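A stats matrix in the (2, dim + 1) layout built above is all that is needed to normalize features. Here is a minimal sketch, assuming global CMVN stats saved with --out-filetype npy; the helper name apply_cmvn and the epsilon floor are illustrative, not part of the tool above.

import numpy as np

def apply_cmvn(feats, cmvn_stats, norm_vars=True, eps=1e-20):
    """feats: (Time, Dim); cmvn_stats: (2, Dim + 1) as accumulated above."""
    n = cmvn_stats[0, -1]                     # total frame count
    mean = cmvn_stats[0, :-1] / n
    var = cmvn_stats[1, :-1] / n - mean ** 2  # E[x^2] - (E[x])^2
    out = feats - mean
    if norm_vars:
        out = out / np.maximum(np.sqrt(var), eps)
    return out

# e.g. stats = np.load('cmvn.npy'); normed = apply_cmvn(feats, stats)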
Example #27
                        'If omitted, then output to sys.stdout')
    return parser


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    args.scps = [args.scps]

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    # List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
    input_infos = []
    output_infos = []
    infos = []
    for lis_list, key_scps_list in [(input_infos, args.input_scps),
                                    (output_infos, args.output_scps),
                                    (infos, args.scps)]:
        for key_scps in key_scps_list:
            lis = []
            for key_scp in key_scps:
                sps = key_scp.split(':')
                if len(sps) == 2:
                    key, scp = sps
                    type_func = None
示例#28
0
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--write-num-frames',
                        type=str,
                        help='Specify the wspecifier for utt2num_frames')
    parser.add_argument('--filetype',
                        type=str,
                        default='mat',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for output. '
                        '"mat" is the matrix format in Kaldi')
    parser.add_argument('--format',
                        type=str,
                        default=None,
                        help='The file format for output pcm. '
                        'This option is only valid '
                        'when "--filetype" is "sound.hdf5" or "sound"')
    parser.add_argument('--compress',
                        type=strtobool,
                        default=False,
                        help='Save in compressed format')
    parser.add_argument(
        '--compression-method',
        type=int,
        default=2,
        help='Specify the compression method (if mat) or gzip level (if hdf5)')
    parser.add_argument('--verbose',
                        '-V',
                        default=0,
                        type=int,
                        help='Verbose option')
    parser.add_argument('--normalize',
                        choices=[1, 16, 24, 32],
                        type=int,
                        default=None,
                        help='Give the bit depth of the PCM, '
                        'then normalize the data to the range [-1, 1]')
    parser.add_argument('--preprocess-conf',
                        type=str,
                        default=None,
                        help='The configuration file for the pre-processing')
    parser.add_argument('--keep-length',
                        type=strtobool,
                        default=True,
                        help='Truncate or zero-pad the output if its length '
                        'is changed from the input by preprocessing')
    parser.add_argument('rspecifier', type=str, help='WAV scp file')
    parser.add_argument('--segments',
                        type=str,
                        help='segments-file format: each line is '
                        '"<segment-id> <recording-id> <start-time> <end-time>", '
                        'e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5')
    parser.add_argument('wspecifier', type=str, help='Write specifier')
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    if args.preprocess_conf is not None:
        preprocessing = Transformation(args.preprocess_conf)
        logging.info('Apply preprocessing: {}'.format(preprocessing))
    else:
        preprocessing = None

    with FileWriterWrapper(args.wspecifier,
                           filetype=args.filetype,
                           write_num_frames=args.write_num_frames,
                           compress=args.compress,
                           compression_method=args.compression_method,
                           pcm_format=args.format) as writer:
        for utt_id, (rate,
                     array) in kaldiio.ReadHelper(args.rspecifier,
                                                  args.segments):
            if args.filetype == 'mat':
                # Kaldi matrices don't support integer types
                array = array.astype(numpy.float32)

            if array.ndim == 1:
                # (Time) -> (Time, Channel)
                array = array[:, None]

            if args.normalize is not None and args.normalize != 1:
                array = array.astype(numpy.float32)
                array = array / (1 << (args.normalize - 1))

            if preprocessing is not None:
                orgtype = array.dtype
                out = preprocessing(array, uttid_list=utt_id)
                out = out.astype(orgtype)

                if args.keep_length:
                    if len(out) > len(array):
                        # Truncate the output back to the input length
                        out = out[:len(array)]
                    elif len(out) < len(array):
                        # The length can be changed by stft, for example;
                        # zero-pad the output back to the input length.
                        out = numpy.pad(out, [(0, len(array) - len(out))] +
                                        [(0, 0) for _ in range(out.ndim - 1)],
                                        mode='constant')

                array = out

            # shape = (Time, Channel)
            if args.filetype in ['sound.hdf5', 'sound']:
                # Write Tuple[int, numpy.ndarray] (scipy style)
                writer[utt_id] = (rate, array)
            else:
                writer[utt_id] = array
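The --normalize branch above maps integer PCM into [-1, 1) by dividing by 2**(bits - 1). A short illustration of that scaling for 16-bit samples:

import numpy

pcm = numpy.array([-32768, 0, 16384, 32767], dtype=numpy.int16)
bits = 16
# Divide by 2**(bits - 1) = 32768, mirroring the code above
scaled = pcm.astype(numpy.float32) / (1 << (bits - 1))
print(scaled)  # approximately [-1.0, 0.0, 0.5, 0.99997]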
Example #29
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    aggregate_stats_dirs(**kwargs)
Example #30
def main():
    parser = get_parser()
    args = parser.parse_args()

    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())
    if len(args.reffiles) != len(args.enhfiles):
        raise RuntimeError(
            'The number of ref files differs from the number '
            'of enh files: {} != {}'.format(len(args.reffiles),
                                            len(args.enhfiles)))
    if len(args.enhfiles) == 1:
        args.permutation = False

    # Read the text files and create a mapping of key2filepath
    reffiles_dict = OrderedDict()  # Dict[str, Dict[str, str]]
    for ref in args.reffiles:
        d = OrderedDict()
        with open(ref, 'r') as f:
            for line in f:
                key, path = line.split(None, 1)
                d[key] = path.rstrip()
        reffiles_dict[ref] = d

    enhfiles_dict = OrderedDict()  # Dict[str, Dict[str, str]]
    for enh in args.enhfiles:
        d = OrderedDict()
        with open(enh, 'r') as f:
            for line in f:
                key, path = line.split(None, 1)
                d[key] = path.rstrip()
        enhfiles_dict[enh] = d

    if args.keylist is not None:
        with open(args.keylist, 'r') as f:
            keylist = [line.rstrip().split()[0] for line in f]
    else:
        # Fall back to the keys of the first ref file
        keylist = list(next(iter(reffiles_dict.values())))

    if len(keylist) == 0:
        raise RuntimeError('No keys are found')

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    evaltypes = []
    for evaltype in args.evaltypes:
        if evaltype == 'SDR':
            evaltypes += ['SDR', 'ISR', 'SIR', 'SAR']
        else:
            evaltypes.append(evaltype)

    # Open files in write mode
    writers = {k: open(os.path.join(args.outdir, k), 'w') for k in evaltypes}

    for key in keylist:
        # 1. Load ref files
        rate_prev = None

        ref_signals = []
        for listname, d in reffiles_dict.items():
            if key not in d:
                raise RuntimeError('{} doesn\'t exist in {}'
                                   .format(key, listname))
            filepath = d[key]
            signal, rate = soundfile.read(filepath, dtype=np.int16)
            if signal.ndim == 1:
                # (Nframe) -> (Nframe, 1)
                signal = signal[:, None]
            ref_signals.append(signal)
            if rate_prev is not None and rate != rate_prev:
                raise RuntimeError('Sampling rate mismatch')
            rate_prev = rate

        # 2. Load enh files
        enh_signals = []
        for listname, d in enhfiles_dict.items():
            if key not in d:
                raise RuntimeError('{} doesn\'t exist in {}'
                                   .format(key, listname))
            filepath = d[key]
            signal, rate = soundfile.read(filepath, dtype=np.int16)
            if signal.ndim == 1:
                # (Nframe) -> (Nframe, 1)
                signal = signal[:, None]
            enh_signals.append(signal)
            if rate_prev is not None and rate != rate_prev:
                raise RuntimeError('Sampling rate mismatch')
            rate_prev = rate

        for signal in ref_signals + enh_signals:
            if signal.shape[1] != ref_signals[0].shape[1]:
                raise RuntimeError('Channel count mismatch')

        # 3. Zero padding to adjust the length to the maximum length in inputs
        ml = max(len(s) for s in ref_signals + enh_signals)
        ref_signals = [np.pad(s, [(0, ml - len(s)), (0, 0)], mode='constant')
                       if len(s) < ml else s for s in ref_signals]

        enh_signals = [np.pad(s, [(0, ml - len(s)), (0, 0)], mode='constant')
                       if len(s) < ml else s for s in enh_signals]

        # ref_signals, enh_signals: (Nsrc, Nframe, Nmic)
        ref_signals = np.stack(ref_signals, axis=0)
        enh_signals = np.stack(enh_signals, axis=0)

        # 4. Evaluate
        for evaltype in args.evaltypes:
            if evaltype == 'SDR':
                (sdr, isr, sir, sar, perm) = \
                    museval.metrics.bss_eval(
                        ref_signals, enh_signals,
                        window=np.inf, hop=np.inf,
                        compute_permutation=args.permutation,
                        filters_len=512,
                        framewise_filters=args.bss_eval_version == 'v3',
                        bsseval_sources_version=not args.bss_eval_images)

                # sdr: (Nsrc, Nframe)
                writers['SDR'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, sdr[:, 0]))))
                writers['ISR'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, isr[:, 0]))))
                writers['SIR'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, sir[:, 0]))))
                writers['SAR'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, sar[:, 0]))))

            elif evaltype == 'STOI':
                stoi, perm = eval_STOI(ref_signals, enh_signals, rate,
                                       extended=False,
                                       compute_permutation=args.permutation)
                writers['STOI'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, stoi))))

            elif evaltype == 'ESTOI':
                estoi, perm = eval_STOI(ref_signals, enh_signals, rate,
                                        extended=True,
                                        compute_permutation=args.permutation)
                writers['ESTOI'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, estoi))))

            elif evaltype == 'PESQ':
                pesq, perm = eval_PESQ(ref_signals, enh_signals, rate,
                                       compute_permutation=args.permutation)
                writers['PESQ'].write(
                    '{} {}\n'.format(key, ' '.join(map(str, pesq))))
            else:
                # Unreachable
                raise RuntimeError
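eval_STOI above is a helper defined elsewhere in this codebase; for a single reference/enhanced pair, the underlying STOI score can be computed with the pystoi package. A minimal mono sketch with hypothetical file paths and no permutation search:

import soundfile
from pystoi import stoi  # pip install pystoi

ref, rate = soundfile.read('ref.wav')    # hypothetical paths
enh, rate2 = soundfile.read('enh.wav')
if rate != rate2:
    raise RuntimeError('Sampling rate mismatch')

n = min(len(ref), len(enh))              # STOI expects equal-length inputs
score = stoi(ref[:n], enh[:n], rate, extended=False)  # extended=True -> ESTOI
print('STOI: {:.3f}'.format(score))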