示例#1
0
def run():
    """Run the extraction network on every listed mixture and dump the output.

    All inputs and outputs live under one hard-coded experiment directory:
    mixtures and auxiliary utterances are read from
    ``data/wsj0_2mix/tt/mix.scp`` and ``data/wsj0_2mix/tt/aux.scp``, the
    checkpoint from ``exp_epoch114/conv_tasnet/conv-net``, and every network
    output is written to ``rec/spk<idx>/<key>.wav`` with ``fs=8000``.

    Each output is cut to the mixture length and rescaled so that its peak
    matches the peak of the mixture.
    """
    exp_dir = ("/export/home/clx214/gm/ntu_project/"
               "SpEx_SincNetAuxCNNEncoder_MultiOriEncoder_share_min_2spk")
    fs = 8000

    mix_input = WaveReader(
        os.path.join(exp_dir, "data/wsj0_2mix/tt/mix.scp"), sample_rate=fs)
    aux_input = WaveReader(
        os.path.join(exp_dir, "data/wsj0_2mix/tt/aux.scp"), sample_rate=fs)
    # NOTE(review): the second argument is presumably the GPU id - confirm
    # against NnetComputer.
    computer = NnetComputer(
        os.path.join(exp_dir, "exp_epoch114/conv_tasnet/conv-net"), 0)

    for key, mix_samps in mix_input:
        # The auxiliary utterance is looked up under the same key as the mix.
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps, aux_samps, len(aux_samps))
        # Peak (infinity norm) of the mixture, used as the target peak.
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture peak. An all-zero output has peak 0;
            # dividing by it would turn the whole signal into NaN, so such
            # an output is written unchanged.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            write_wav(
                os.path.join(exp_dir, "rec",
                             "spk{}/{}.wav".format(idx + 1, key)),
                samps,
                fs=fs)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
示例#2
0
def run(args):
    """Score separated audio against references with SI-SNR and report it.

    ``args.sep_scp`` is a comma-separated list of script files. With a single
    entry each utterance is scored directly with ``si_snr``; with several
    entries the per-speaker lists are scored with ``permute_si_snr``.
    Estimate and reference are trimmed to the shorter of the two before
    scoring. Results are collected in a ``Report`` built from
    ``args.spk2gender`` and reported at the end.
    """
    num_scps = len(args.sep_scp.split(","))
    reporter = Report(args.spk2gender)

    if num_scps != 1:
        estimates = SpeakersReader(args.sep_scp)
        references = SpeakersReader(args.ref_scp)
        for utt, est_list in tqdm(estimates):
            tgt_list = references[utt]
            # Only the first speaker's length is compared; every signal is
            # then cut to the common length.
            est_len = est_list[0].size
            tgt_len = tgt_list[0].size
            if est_len != tgt_len:
                keep = min(est_len, tgt_len)
                est_list = [est[:keep] for est in est_list]
                tgt_list = [tgt[:keep] for tgt in tgt_list]
            reporter.add(utt, permute_si_snr(est_list, tgt_list))
    else:
        estimates = WaveReader(args.sep_scp)
        references = WaveReader(args.ref_scp)
        for utt, est in tqdm(estimates):
            tgt = references[utt]
            if est.size != tgt.size:
                keep = min(est.size, tgt.size)
                est, tgt = est[:keep], tgt[:keep]
            reporter.add(utt, si_snr(est, tgt))

    reporter.report()
示例#3
0
def run(args):
    """Score every ``*.scp`` found in a separation folder with SI-SNR.

    ``args.sep_scp`` is a folder; the ``*.scp`` files directly inside it are
    collected in sorted order and joined into a comma-separated list. One
    file means single-speaker scoring with ``si_snr``; several files mean
    permutation scoring over the per-speaker lists.

    When ``args.mixofmix`` is non-zero the separation may contain more
    outputs than there are references: the reference list is padded with
    near-zero signals, ``permute_si_snr_mix_of_mix`` reorders the outputs,
    and only the first ``len(ref_list)`` of them are scored.

    Raises:
        RuntimeError: in mix-of-mix mode, if there are more references than
            separated outputs.
    """
    print("Working on folder {}".format(args.sep_scp))
    # Comma-separated list of all scp files in the separation folder.
    scp_files = sorted(glob(args.sep_scp + '/*.scp', recursive=False))
    sep_scp = ",".join(scp_files)

    single_speaker = len(sep_scp.split(",")) == 1
    reporter = Report(args.spk2gender, outputDir=args.sep_scp)

    if single_speaker:
        estimates = WaveReader(sep_scp)
        references = WaveReader(args.ref_scp)
        for utt, est in tqdm(estimates):
            tgt = references[utt]
            if est.size != tgt.size:
                keep = min(est.size, tgt.size)
                est, tgt = est[:keep], tgt[:keep]
            reporter.add(utt, si_snr(est, tgt))
    else:
        estimates = SpeakersReader(sep_scp)
        references = SpeakersReader(args.ref_scp)

        for utt, est_list in tqdm(estimates):
            tgt_list = references[utt]
            # Second lookup: this list is the one padded with zero references.
            padded_tgt_list = references[utt]
            mix_of_mix = args.mixofmix != 0
            if mix_of_mix:
                missing = len(est_list) - len(tgt_list)
                if missing < 0:
                    raise RuntimeError("There are more references then separs")
                # One near-zero reference per surplus separated output.
                for _ in range(missing):
                    padded_tgt_list.append(np.zeros_like(tgt_list[0]) + 0.0001)

            # Trim everything to the shorter of estimate and reference.
            est_len = est_list[0].size
            tgt_len = tgt_list[0].size
            if est_len != tgt_len:
                keep = min(est_len, tgt_len)
                est_list = [sig[:keep] for sig in est_list]
                tgt_list = [sig[:keep] for sig in tgt_list]
                padded_tgt_list = [sig[:keep] for sig in padded_tgt_list]

            if mix_of_mix:
                # Pick the output ordering that matches the padded references.
                chosen = permute_si_snr_mix_of_mix(est_list, padded_tgt_list)
            else:
                chosen = est_list

            # Permutation-invariant score against the real references only.
            reporter.add(utt, permute_si_snr(chosen[:len(tgt_list)], tgt_list))

    reporter.report()
示例#4
0
 def __init__(self, scps):
     """Build one ``WaveReader`` per script in a comma-separated list.

     Args:
         scps: comma-separated script paths; at least two are required.

     Raises:
         RuntimeError: if ``scps`` contains a single script only.
     """
     scp_paths = scps.split(",")
     # str.split always yields at least one item, so "< 2" means exactly one.
     if len(scp_paths) < 2:
         message = "Construct SpeakersReader need more than one script, got {}"
         raise RuntimeError(message.format(scps))
     readers = []
     for scp_path in scp_paths:
         readers.append(WaveReader(scp_path))
     self.readers = readers
示例#5
0
def run(args):
    """Separate every mixture in ``args.input`` and dump wavs plus scp files.

    Creates ``args.dump_dir`` (``os.mkdir`` fails if it already exists),
    copies the model info next to the results, writes each network output to
    ``<dump_dir>/spk<idx>/<key>.wav`` and finally generates one
    ``spk<idx>.scp`` per output directory. With ``args.plot`` non-zero a plot
    of the outputs is written to ``<dump_dir>/plot_spk/<key>.png``.

    Each output is cut to the mixture length and rescaled so that its peak
    matches the peak of the mixture.
    """
    os.mkdir(args.dump_dir)
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    cpyModelInfo(args.checkpoint, args.dump_dir)
    # Number of outputs produced for the last utterance; drives scp creation.
    lenGen = 0
    for key, mix_samps in tqdm(mix_input):
        # NOTE(review): `logging` is not defined in this block; if it is the
        # stdlib module rather than a bool flag this is never true - confirm.
        if logging is True:
            logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps)
        # Peak (infinity norm) of the mixture, used as the target peak.
        norm = np.linalg.norm(mix_samps, np.inf)
        lenGen = len(spks)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture peak. An all-zero output has peak 0;
            # dividing by it would turn the whole signal into NaN, so such
            # an output is written unchanged.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak

            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
        if args.plot != 0:
            plotOutputs(
                os.path.join(args.dump_dir, "plot_spk/{}.png".format(key)),
                spks)
    # Generate one scp file per speaker output directory.
    for idx in range(lenGen):
        generateFile(
            os.path.join(args.dump_dir, "spk{}".format(idx + 1)),
            os.path.join(args.dump_dir, "spk{}.scp".format(idx + 1)))
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
示例#6
0
def run():
    """Run the extraction network on every listed mixture and dump the output.

    Mixtures and auxiliary utterances are read from
    ``data/wsj0_2mix/tt/mix.scp`` and ``data/wsj0_2mix/tt/aux.scp`` (relative
    to the working directory), the checkpoint from
    ``exp_epoch114/conv_tasnet/conv-net``, and results are written to
    ``rec/<key>.wav`` with ``fs=8000``.

    Each output is cut to the mixture length and rescaled so that its peak
    matches the peak of the mixture.
    """
    mix_input = WaveReader("data/wsj0_2mix/tt/mix.scp", sample_rate=8000)
    aux_input = WaveReader("data/wsj0_2mix/tt/aux.scp", sample_rate=8000)
    # NOTE(review): the second argument is presumably the GPU id - confirm
    # against NnetComputer.
    computer = NnetComputer("exp_epoch114/conv_tasnet/conv-net", 3)
    for key, mix_samps in mix_input:
        # The auxiliary utterance is looked up under the same key as the mix.
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps, aux_samps, len(aux_samps))
        # Peak (infinity norm) of the mixture, used as the target peak.
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture peak. An all-zero output has peak 0;
            # dividing by it would turn the whole signal into NaN, so such
            # an output is written unchanged.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            # NOTE(review): the file name does not include `idx`, so if the
            # network returns more than one output, later ones overwrite
            # earlier ones - confirm a single output is expected here.
            write_wav(os.path.join("rec/", "{}.wav".format(key)),
                      samps,
                      fs=8000)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
示例#7
0
def run(args):
    """Separate every mixture in ``args.input`` and write one wav per output.

    The network is loaded from ``args.checkpoint`` on ``args.gpu``; each
    output is written to ``<args.dump_dir>/spk<idx>/<key>.wav`` at
    ``args.fs``.

    Each output is cut to the mixture length and rescaled so that its peak
    matches the peak of the mixture.
    """
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    for key, mix_samps in mix_input:
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps)
        # Peak (infinity norm) of the mixture, used as the target peak.
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture peak. An all-zero output has peak 0;
            # dividing by it would turn the whole signal into NaN, so such
            # an output is written unchanged.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
示例#8
0
def run2(args):
    """Separate "mix of mixes": each mixture summed with a second random one.

    For every utterance a second mixture is drawn at random from the keys
    that do not match a pattern built from the two speaker prefixes taken
    from the current key (its first three characters and the first three
    characters of its third ``_``-separated field). The second mixture is
    truncated or zero-padded to the length of the first, the two are summed
    and the sum is fed to the network.

    Outputs go to ``<dump_dir>/spk<idx>/<key>.wav``; one ``spk<idx>.scp`` is
    generated per output directory afterwards. With ``args.plot`` non-zero a
    plot is written to ``<dump_dir>/plot_spk/<key>.png``.

    Each output is cut to the input length and rescaled so that its peak
    matches the peak of the summed input.
    """
    os.mkdir(args.dump_dir)
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    cpyModelInfo(args.checkpoint, args.dump_dir)
    # Number of outputs produced for the last utterance; drives scp creation.
    lenGen = 0
    for key, mix_samps in tqdm(mix_input):
        # NOTE(review): `logging` is not defined in this block; if it is the
        # stdlib module rather than a bool flag this is never true - confirm.
        if logging is True:
            logger.info("Compute on utterance {}...".format(key))

        firstSpeaker = key[:3]
        secondSpeaker = key.split('_')[2][:3]

        # The pattern depends only on the current utterance, so it is built
        # once here instead of once per candidate key.
        either = "((" + firstSpeaker + ")|(" + secondSpeaker + "))"
        speaker_pattern = ("(" + either + "?.*_.*_" + either + ".*)|("
                           + either + ".*_.*_" + either + "?.*)")
        # A distinct loop name keeps the outer `key` visibly untouched.
        aviableKeys = [
            candidate for candidate in mix_input.index_keys
            if not re.search(speaker_pattern, candidate)
        ]

        secondKey = aviableKeys[random.randint(0, len(aviableKeys) - 1)]
        secondMix = mix_input[secondKey]

        # Bring the second mixture to the length of the first one.
        if len(mix_samps) < len(secondMix):
            secondMix = secondMix[:len(mix_samps)]
        else:
            secondMix = np.pad(secondMix,
                               (0, len(mix_samps) - len(secondMix)),
                               "constant",
                               constant_values=(0, 0))

        mixofmix = mix_samps + secondMix

        spks = computer.compute(mixofmix)
        # Peak (infinity norm) of the summed input, used as the target peak.
        norm = np.linalg.norm(mixofmix, np.inf)
        lenGen = len(spks)
        for idx, samps in enumerate(spks):
            samps = samps[:mixofmix.size]
            # Rescale to the input peak. An all-zero output has peak 0;
            # dividing by it would turn the whole signal into NaN, so such
            # an output is written unchanged.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak

            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
        if args.plot != 0:
            plotOutputs(
                os.path.join(args.dump_dir, "plot_spk/{}.png".format(key)),
                spks)
    # Generate one scp file per speaker output directory.
    for idx in range(lenGen):
        generateFile(
            os.path.join(args.dump_dir, "spk{}".format(idx + 1)),
            os.path.join(args.dump_dir, "spk{}.scp".format(idx + 1)))
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
示例#9
0
def run(args):
    """Create source/mixture wav pairs from a CSV of utterance id triples.

    Every non-empty row of ``args.csv`` holds ``src_id, ref_id, itf_id``. For
    each row the source and interfering utterances are read from
    ``args.wav_scp`` and passed to ``mix_audio``; the returned source and
    mixture are written to ``<dump_dir>/src/<ids>.wav`` and
    ``<dump_dir>/mix/<ids>.wav``, where ``<ids>`` is the three ids joined by
    ``_``. A line ``<ids>\\t<ref_id>`` is appended to ``<dump_dir>/emb.key``.
    """
    if args.dump_dir:
        os.makedirs(args.dump_dir, exist_ok=True)
    wave_reader = WaveReader(args.wav_scp)
    # newline="" is what the csv module asks for on files handed to a reader.
    with open(os.path.join(args.dump_dir, "emb.key"), "w") as emb, \
            open(args.csv, "r", newline="") as f:
        reader = csv.reader(f)
        for ids in tqdm.tqdm(reader):
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # skip them instead of failing on the 3-way unpack below.
            if not ids:
                continue
            src_id, ref_id, itf_id = ids
            mix_name = "_".join(ids)
            emb.write("{}\t{}\n".format(mix_name, ref_id))
            src = wave_reader[src_id]
            itf = wave_reader[itf_id]
            src, mix = mix_audio(src, itf)
            write_wav(
                os.path.join(args.dump_dir, "src/{}.wav".format(mix_name)),
                src)
            write_wav(
                os.path.join(args.dump_dir, "mix/{}.wav".format(mix_name)),
                mix)
示例#10
0
def separating(args):
    """Separate every mixture in ``args.input`` with a ConvTasNet checkpoint.

    The network is built from the hyper-parameters on ``args``, its weights
    are loaded from ``args.model_path`` and it is moved to the ``cuda``
    device. Each output is written to ``<args.dump_dir>/spk<idx>/<key>`` at
    ``args.fs``.

    Each output is cut to the mixture length and rescaled so that its peak
    matches the peak of the mixture.
    """
    # Network configuration, forwarded as keyword arguments to ConvTasNet.
    nnet_conf = {
        "L": args.L,
        "N": args.N,
        "X": args.X,
        "R": args.R,
        "B": args.B,
        "H": args.H,
        "P": args.P,
        "norm": args.norm,
        "num_spks": args.num_spks,
        "non_linear": args.non_linear,
        "causal": args.causal,
    }
    nnet = ConvTasNet(**nnet_conf)
    nnet.load_state_dict(flow.load(args.model_path))
    device = flow.device("cuda")
    nnet.to(device)
    # A freshly constructed module is in training mode; switch to inference
    # mode so mode-dependent layers (e.g. batch normalisation) use their
    # stored statistics instead of per-utterance ones.
    nnet.eval()

    with flow.no_grad():
        mix_input = WaveReader(args.input, sample_rate=args.fs)
        for key, mix_samps in mix_input:
            raw = flow.tensor(mix_samps, dtype=flow.float32, device=device)
            sps = nnet(raw)
            spks = [np.squeeze(s.detach().cpu().numpy()) for s in sps]
            # Peak (infinity norm) of the mixture, used as the target peak.
            norm = np.linalg.norm(mix_samps, np.inf)
            for idx, samps in enumerate(spks):
                samps = samps[:mix_samps.size]
                # Rescale to the mixture peak. An all-zero output has peak 0;
                # dividing by it would turn the whole signal into NaN, so
                # such an output is written unchanged.
                peak = np.max(np.abs(samps))
                if peak > 0:
                    samps = samps * norm / peak
                # NOTE(review): no ".wav" suffix is appended here -
                # presumably the keys already carry one; confirm.
                write_wav(
                    os.path.join(args.dump_dir,
                                 "spk{}/{}".format(idx + 1, key)),
                    samps,
                    fs=args.fs,
                )
示例#11
0
def run(args):
    """Run the network on every utterance and write speech/noise wav files.

    For each utterance the first network output (``idx == 0``) is written to
    ``<dump_dir>/<key>.speech.wav`` and the second (``idx == 1``) to
    ``<dump_dir>/<key>.noise.wav``; any further outputs are not written.

    Extra files for the first output:

    * when ``computer.nnet.causal`` is false, a copy rescaled by
      ``<samps, mix> / <samps, samps>`` is written to the speech path with
      every occurrence of the checkpoint basename replaced by
      ``<basename>_scale_optimized``;
    * when ``args.online == 1``, a copy scaled by 0.89 is written under the
      original input file's basename, with ``<basename>_online`` in the path
      replaced by ``<basename>_online/submit``.

    In online mode the computer is re-initialised before every utterance.
    """
    # get_filepath=True: each item is indexed below as (samples, file path).
    mix_input = WaveReader(args.input, sample_rate=args.fs, get_filepath=True)
    # Checkpoint basename; used as the substring rewritten in output paths.
    cpt_tag = os.path.basename(args.checkpoint)

    if args.online == 1:
        computer = NnetComputer(args.checkpoint, args.gpu, online=True)
    else:
        computer = NnetComputer(args.checkpoint, args.gpu)

    for key, mix_samps in mix_input:
        # Unpack the (samples, original file path) pair.
        filepath_org = mix_samps[1]
        mix_samps = mix_samps[0]
        logger.info("Compute on utterance {}...".format(key))

        if args.online == 1:
            # re-initialize
            computer.init_online(args.checkpoint, args.gpu)

        spks = computer.compute(mix_samps)

        # Peak (infinity norm) of the mixture. It is only referenced by the
        # disabled rescaling branch below, so it currently has no effect.
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # norm
            if computer.nnet.causal:
                # Unity gain; the commented lines are alternative fixed gains.
                samps = samps * 1.00
                # samps = samps * 0.89 # consistently -1dB
                # samps = samps * 0.71 # consistently -3dB
                # samps = samps * 0.56 # consistently -5dB
                # Only reports potential clipping; the signal is not changed.
                if np.max(np.abs(samps)) >= 1.0:

                    print("clipping!")

            else:
                # Peak matching to the mixture is switched off here.
                if False:
                    samps = samps * norm / np.max(np.abs(samps))

            if idx == 0:
                fname = os.path.join(args.dump_dir,
                                     "{}.speech.wav".format(key))
                write_wav(fname, samps, fs=args.fs)

                if not computer.nnet.causal:
                    # Scalar gain that projects the mixture onto the estimate.
                    # NOTE(review): np.dot(samps, samps) is 0 for an all-zero
                    # output, which would make this division produce NaN.
                    ip = np.dot(samps, mix_samps)
                    samps_scale_optimized = ip / np.dot(samps, samps) * samps
                    # The path only changes if it contains cpt_tag
                    # (presumably dump_dir does - confirm).
                    fname_scale_optimized = fname.replace(
                        cpt_tag, "%s_scale_optimized" % (cpt_tag))
                    write_wav(fname_scale_optimized,
                              samps_scale_optimized,
                              fs=args.fs)

                if args.online == 1:
                    # Submission copy, named after the original input file.
                    fname_submit = os.path.join("%s" % (args.dump_dir),
                                                os.path.basename(filepath_org))
                    fname_submit = fname_submit.replace(
                        "%s_online" % (cpt_tag),
                        "%s_online/submit" % (cpt_tag))
                    write_wav(fname_submit, samps * 0.89,
                              fs=args.fs)  # -1dB (to avoid clipping)

                    #import pdb; pdb.set_trace()

            elif idx == 1:
                fname = os.path.join(args.dump_dir, "{}.noise.wav".format(key))
                write_wav(fname, samps, fs=args.fs)

    logger.info("Compute over {:d} utterances".format(len(mix_input)))