def run(args):
    """Beamform every utterance listed in ``args.wav_scp`` and write wavs.

    For each utterance the STFT read from ``args.wav_scp`` is combined with
    the speech mask stored under the same key in ``args.mask_scp``. An MVDR
    beamformer is used when ``args.beamformer == "mvdr"``; any other value
    selects a GEVD beamformer. The enhanced signal is written, without
    normalization, to ``<args.dst_dir>/<key>.wav``.
    """
    # One set of framing options shared by the reader (forward STFT) and
    # istft so both sides agree. Set ``center`` to False to stay comparable
    # with Kaldi's framing.
    stft_opts = dict(
        frame_length=args.frame_length,
        frame_shift=args.frame_shift,
        window=args.window,
        center=args.center,
        transpose=False,
    )
    wav_reader = SpectrogramReader(args.wav_scp, **stft_opts)
    masks = ArchieveReader(args.mask_scp)

    num_bins = nfft(args.frame_length) // 2 + 1
    if args.beamformer == "mvdr":
        beamformer = MvdrBeamformer(num_bins)
    else:
        beamformer = GevdBeamformer(num_bins)

    processed = 0
    for utt_id, spectrum in wav_reader:
        processed += 1
        logger.info("Processing utterance {}...".format(utt_id))
        enhanced = beamformer.run(masks[utt_id], spectrum)
        # Output is deliberately left unnormalized.
        out_path = os.path.join(args.dst_dir, '{}.wav'.format(utt_id))
        istft(out_path, enhanced, **stft_opts)
    logger.info("Processed {} utterances".format(processed))
# Example #2
def run(args):
    """Separate every utterance in ``args.wave_scp`` with a PIT network.

    Model and STFT settings come from the YAML file ``args.config``. The
    network is wrapped in a ``Separator`` built from ``args.state_dict``.
    For every utterance, one ``<key>.spk<N>.wav`` per separated speaker is
    written into ``args.dump_dir`` (at fs=8000). When ``args.dump_mask`` is
    set, each speaker's mask is also saved as ``<key>.spk<N>.mat`` via
    ``sio.savemat``. Utterances whose audio file is missing are skipped
    with a message.

    Raises:
        FileNotFoundError: if the config names an ``mvn_dict`` path that
            does not exist.
    """
    num_bins, config_dict = parse_yaml(args.config)
    dataloader_conf = config_dict["dataloader"]
    spectrogram_conf = config_dict["spectrogram_reader"]

    # Optional mean/variance normalization statistics stored as a pickle.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # "mvn_dict" at trusted files.
    dict_mvn = dataloader_conf["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)
    # Log compression stays on unless the config explicitly disables it.
    apply_log = dataloader_conf.get("apply_log", True)

    pitnet = PITNet(num_bins, **config_dict["model"])

    frame_length = spectrogram_conf["frame_length"]
    frame_shift = spectrogram_conf["frame_shift"]
    window = spectrogram_conf["window"]

    separator = Separator(pitnet, args.state_dict, cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
    for key, utt in utt_dict.items():
        try:
            samps, mix_stft = stft(utt,
                                   frame_length=frame_length,
                                   frame_shift=frame_shift,
                                   window=window,
                                   center=True,
                                   return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        # Infinity norm of the input samples (peak |x| if samps is 1-D).
        # It is passed to istft as `norm`, presumably to restore the
        # input level; confirm against istft.
        norm = np.linalg.norm(samps, np.inf)
        spk_mask, spk_spectrogram = separator.seperate(mix_stft,
                                                       cmvn=dict_mvn,
                                                       apply_log=apply_log)

        # A distinct loop name avoids shadowing the mixture STFT above.
        for index, spk_stft in enumerate(spk_spectrogram):
            istft(os.path.join(args.dump_dir,
                               '{}.spk{}.wav'.format(key, index + 1)),
                  spk_stft,
                  frame_length=frame_length,
                  frame_shift=frame_shift,
                  window=window,
                  center=True,
                  norm=norm,
                  fs=8000,
                  nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir,
                                 '{}.spk{}.mat'.format(key, index + 1)),
                    {"mask": spk_mask[index]})
    print("Processed {} utterance!".format(num_utts))
# Example #3
def run(args):
    """Reconstruct each speaker by masking the mixture with reference masks.

    For every utterance in ``args.mix_scp`` that has a target in *every*
    scp of ``args.ref_scp``, one mask per target is computed with
    ``compute_mask``. ``args.psm`` selects a phase-sensitive mask instead of
    a ratio mask, per the printed label. The masked mixture is then written
    as ``<key>.spk<N>.wav`` into ``args.dump_dir`` (at fs=8000). Utterances
    missing any target are skipped with a message.
    """
    # STFT framing shared by all readers and by istft.
    reader_kwargs = {
        "frame_length": args.frame_length,
        "frame_shift": args.frame_shift,
        "window": args.window,
        "center": True
    }
    print(
        "Using {} Mask".format("Ratio" if not args.psm else "Phase Sensitive"))
    # The mixture reader also yields raw samples (return_samps=True).
    mixture_reader = SpectrogramReader(args.mix_scp,
                                       **reader_kwargs,
                                       return_samps=True)
    targets_reader = [
        SpectrogramReader(scp, **reader_kwargs) for scp in args.ref_scp
    ]
    num_utts = 0
    for key, (samps, mixture) in mixture_reader:
        if any(key not in reader for reader in targets_reader):
            print("Skip utterance {}, missing targets".format(key))
            continue
        num_utts += 1
        if num_utts % 1000 == 0:
            print("Processed {} utterance...".format(num_utts))
        # Only computed for utterances that are actually processed. It is
        # the infinity norm of the input (peak |x| if samps is 1-D) and is
        # passed to istft as `norm`.
        norm = np.linalg.norm(samps, np.inf)
        targets_list = [reader[key] for reader in targets_reader]
        spk_masks = compute_mask(mixture, targets_list, args.psm)
        for index, mask in enumerate(spk_masks):
            istft(os.path.join(args.dump_dir,
                               '{}.spk{}.wav'.format(key, index + 1)),
                  mixture * mask,
                  **reader_kwargs,
                  norm=norm,
                  fs=8000,
                  nsamps=samps.size)
    print("Processed {} utterance!".format(num_utts))
# Example #4
def run(args):
    """Separate every utterance in ``args.wave_scp`` with deep clustering.

    Builds a ``DCNet`` from the YAML config ``args.config`` and wraps it in a
    ``DeepCluster`` separator constructed from ``args.dcnet_state`` and
    ``args.num_spks``. For each utterance, one ``<key>.spk<N>.wav`` per
    separated speaker is written into ``args.dump_dir`` (at fs=8000).
    Optional extras:
    - ``args.dump_mask``: also save each speaker's mask as
      ``<key>.spk<N>.mat``.
    - ``args.dump_pca``: also save the separator's PCA matrix as
      ``<key>.mat``.

    Utterances whose audio file is missing are skipped with a message.

    Raises:
        FileNotFoundError: if the config names an ``mvn_dict`` path that
            does not exist.
    """
    num_bins, conf = parse_yaml(args.config)

    # Optional mean/variance normalization statistics stored as a pickle.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # "mvn_dict" at trusted files.
    mvn = conf["dataloader"]["mvn_dict"]
    if mvn:
        if not os.path.exists(mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(mvn, "rb") as fd:
            mvn = pickle.load(fd)

    dcnet = DCNet(num_bins, **conf["dcnet"])

    spec_conf = conf["spectrogram_reader"]
    # Framing options shared by the forward and inverse STFT.
    stft_opts = dict(
        frame_length=spec_conf["frame_length"],
        frame_shift=spec_conf["frame_shift"],
        window=spec_conf["window"],
        center=True,
    )

    cluster = DeepCluster(dcnet, args.dcnet_state, args.num_spks,
                          pca=args.dump_pca, cuda=args.cuda)

    processed = 0
    for key, wav in parse_scps(args.wave_scp).items():
        try:
            samps, spectrum = stft(wav, **stft_opts, return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        processed += 1
        peak = np.linalg.norm(samps, np.inf)
        pca_mat, spk_mask, spk_spectrogram = cluster.seperate(spectrum,
                                                              cmvn=mvn)

        for idx, spk_stft in enumerate(spk_spectrogram):
            spk_stem = os.path.join(args.dump_dir,
                                    '{}.spk{}'.format(key, idx + 1))
            istft(spk_stem + '.wav', spk_stft, **stft_opts,
                  norm=peak, fs=8000, nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(spk_stem + '.mat', {"mask": spk_mask[idx]})
        if args.dump_pca:
            pca_path = os.path.join(args.dump_dir, '{}.mat'.format(key))
            sio.savemat(pca_path, {"pca_matrix": pca_mat})
    print("Processed {} utterance!".format(processed))
def run(args):
    """Run deep-clustering separation over every utterance in a wave scp.

    The YAML config ``args.config`` supplies the DCNet architecture
    (``dcnet``), STFT framing (``spectrogram_reader``) and an optional
    pickled normalization file (``dataloader.mvn_dict``). Separated
    speakers are written as ``<key>.spk<N>.wav`` into ``args.dump_dir``
    (at fs=8000). Masks are saved as ``<key>.spk<N>.mat`` when
    ``args.dump_mask`` is set, and the PCA matrix as ``<key>.mat`` when
    ``args.dump_pca`` is set. Missing audio files are reported and skipped.

    Raises:
        FileNotFoundError: if ``mvn_dict`` is set but the path does not exist.
    """
    num_bins, cfg = parse_yaml(args.config)

    # NOTE(review): pickle.load can execute arbitrary code; the mvn file
    # must come from a trusted source.
    cmvn = cfg["dataloader"]["mvn_dict"]
    if cmvn:
        if not os.path.exists(cmvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(cmvn, "rb") as handle:
            cmvn = pickle.load(handle)

    net = DCNet(num_bins, **cfg["dcnet"])

    reader_cfg = cfg["spectrogram_reader"]
    framing = {
        "frame_length": reader_cfg["frame_length"],
        "frame_shift": reader_cfg["frame_shift"],
        "window": reader_cfg["window"],
        "center": True,
    }

    separator = DeepCluster(net,
                            args.dcnet_state,
                            args.num_spks,
                            pca=args.dump_pca,
                            cuda=args.cuda)

    scp = parse_scps(args.wave_scp)
    count = 0
    for utt_key, src in scp.items():
        try:
            samples, mix_spec = stft(src, **framing, return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(utt_key))
            continue
        print("Processing utterance {}".format(utt_key))
        count += 1
        amp_ref = np.linalg.norm(samples, np.inf)
        pca_mat, masks, spk_specs = separator.seperate(mix_spec, cmvn=cmvn)

        for spk_no, spec in enumerate(spk_specs, start=1):
            wav_name = '{}.spk{}.wav'.format(utt_key, spk_no)
            istft(os.path.join(args.dump_dir, wav_name),
                  spec,
                  **framing,
                  norm=amp_ref,
                  fs=8000,
                  nsamps=samples.size)
            if args.dump_mask:
                mat_name = '{}.spk{}.mat'.format(utt_key, spk_no)
                sio.savemat(os.path.join(args.dump_dir, mat_name),
                            {"mask": masks[spk_no - 1]})
        if args.dump_pca:
            sio.savemat(os.path.join(args.dump_dir, '{}.mat'.format(utt_key)),
                        {"pca_matrix": pca_mat})
    print("Processed {} utterance!".format(count))