Ejemplo n.º 1
0
def run(args):
    """Separate every utterance of ``args.input`` and dump one wav per output.

    For each ``(key, samples)`` pair read from ``args.input`` the network is
    run, every estimate is trimmed to the mixture length, rescaled so that its
    peak matches the mixture's peak, and written to
    ``<dump_dir>/spk<k>/<key>.wav``. When ``args.plot`` is non-zero a plot of
    the outputs is saved as well. Finally one ``spk<k>.scp`` file is generated
    per output channel.

    Args:
        args: parsed command-line namespace; reads ``dump_dir``, ``input``,
            ``fs``, ``checkpoint``, ``gpu`` and ``plot``.

    Raises:
        FileExistsError: from ``os.mkdir`` if ``args.dump_dir`` already exists.
    """
    os.mkdir(args.dump_dir)
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    cpyModelInfo(args.checkpoint, args.dump_dir)
    lenGen = 0
    for key, mix_samps in tqdm(mix_input):
        # NOTE(review): if ``logging`` is the stdlib module this comparison is
        # always False; presumably it is a module-level flag — confirm.
        if logging is True:
            logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps)
        norm = np.linalg.norm(mix_samps, np.inf)
        # The output count of the last utterance decides how many scp files
        # are generated after the loop.
        lenGen = len(spks)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture's peak. A silent estimate is left as is:
            # dividing by a zero peak would fill the output with NaNs.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak

            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
        if args.plot != 0:
            plotOutputs(os.path.join(args.dump_dir, "plot_spk/{}.png".format(key)), spks)
    # Generate one SCP file per output channel.
    for idx in range(lenGen):
        generateFile(os.path.join(args.dump_dir, "spk{}".format(idx + 1)),
                     os.path.join(args.dump_dir, "spk{}.scp".format(idx + 1)))
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
Ejemplo n.º 2
0
def run(mix_scp=None, aux_scp=None, checkpoint=None, gpu=0, dump_dir=None,
        fs=8000):
    """Run target-speaker extraction on every mixture of an scp list.

    Every argument is optional. When omitted, it falls back to the path or
    setting this script originally hard-coded, so ``run()`` behaves exactly as
    before.

    Args:
        mix_scp: scp file listing the mixtures.
        aux_scp: scp file listing the auxiliary utterances; it is indexed with
            the same keys as ``mix_scp``.
        checkpoint: checkpoint directory passed to ``NnetComputer``.
        gpu: device id passed to ``NnetComputer``.
        dump_dir: output root; estimates go to ``<dump_dir>/spk<k>/<key>.wav``.
        fs: sample rate used for both reading and writing audio.
    """
    root = "/export/home/clx214/gm/ntu_project/SpEx_SincNetAuxCNNEncoder_MultiOriEncoder_share_min_2spk"
    if mix_scp is None:
        mix_scp = os.path.join(root, "data/wsj0_2mix/tt/mix.scp")
    if aux_scp is None:
        aux_scp = os.path.join(root, "data/wsj0_2mix/tt/aux.scp")
    if checkpoint is None:
        checkpoint = os.path.join(root, "exp_epoch114/conv_tasnet/conv-net")
    if dump_dir is None:
        dump_dir = os.path.join(root, "rec/")

    mix_input = WaveReader(mix_scp, sample_rate=fs)
    aux_input = WaveReader(aux_scp, sample_rate=fs)
    computer = NnetComputer(checkpoint, gpu)
    for key, mix_samps in mix_input:
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps, aux_samps, len(aux_samps))
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture's peak; skip silent estimates, which
            # would otherwise become NaNs (0 / 0).
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            write_wav(
                os.path.join(dump_dir, "spk{}/{}.wav".format(idx + 1, key)),
                samps,
                fs=fs)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
Ejemplo n.º 3
0
def run(args):
    """Run target-speaker extraction over the whole evaluation dataset.

    Loads ``mdl.json`` and ``best.pt.tar`` from ``args.checkpoint``, and
    iterates over a ``Dataset`` built from ``args.mix_scp`` / ``ref_scp`` /
    ``aux_scp``. Each estimate is rescaled to the mixture's peak and written to
    ``<dump_dir>/<key>``. Progress and timing go to
    ``<checkpoint>/separate.log``.

    Args:
        args: parsed command-line namespace; reads ``checkpoint``, ``mix_scp``,
            ``ref_scp``, ``aux_scp``, ``gpuid`` (negative means CPU),
            ``dump_dir`` and ``fs``.
    """
    start = time.time()
    logger = get_logger(
        os.path.join(args.checkpoint, 'separate.log'), file=True)

    dataset = Dataset(mix_scp=args.mix_scp, ref_scp=args.ref_scp, aux_scp=args.aux_scp)

    # Load model
    nnet_conf = load_json(args.checkpoint, "mdl.json")
    nnet = ConvTasNet(**nnet_conf)
    cpt_fname = os.path.join(args.checkpoint, "best.pt.tar")
    cpt = th.load(cpt_fname, map_location="cpu")
    nnet.load_state_dict(cpt["model_state_dict"])
    logger.info("Load checkpoint from {}, epoch {:d}".format(
        cpt_fname, cpt["epoch"]))

    device = th.device(
        "cuda:{}".format(args.gpuid)) if args.gpuid >= 0 else th.device("cpu")
    nnet = nnet.to(device)
    nnet.eval()

    with th.no_grad():
        total_cnt = 0
        for data in dataset:
            # Inputs are created directly on ``device``. A later ``.cuda()``
            # would move them to the *current* CUDA device (cuda:0 by
            # default), not to cuda:<gpuid> where the model lives.
            mix = th.tensor(data['mix'], dtype=th.float32, device=device)
            aux = th.tensor(data['aux'], dtype=th.float32, device=device)
            aux_len = th.tensor(data['aux_len'], dtype=th.float32, device=device)
            key = data['key']

            # Forward
            ests = nnet(mix, aux, aux_len)
            ests = ests.cpu().numpy()
            norm = np.linalg.norm(mix.cpu().numpy(), np.inf)
            ests = ests[:mix.shape[-1]]

            logger.info("Separate Utt{:d}".format(total_cnt + 1))
            # Rescale to the mixture's peak; a silent estimate would give 0 / 0.
            peak = np.max(np.abs(ests))
            if peak > 0:
                ests = ests * norm / peak
            write_wav(os.path.join(args.dump_dir, key),
                      ests,
                      fs=args.fs)
            total_cnt += 1

    end = time.time()
    logger.info('Utt={:d} | Time Elapsed: {:.1f}s'.format(total_cnt, end-start))
Ejemplo n.º 4
0
def run(args):
    """Separate every utterance of ``args.input`` into per-output wav files.

    Each estimate is trimmed to the mixture length, rescaled to the mixture's
    peak and written to ``<dump_dir>/spk<k>/<key>.wav``.

    Args:
        args: parsed command-line namespace; reads ``input``, ``fs``,
            ``checkpoint``, ``gpu`` and ``dump_dir``.
    """
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    for key, mix_samps in mix_input:
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps)
        norm = np.linalg.norm(mix_samps, np.inf)
        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Rescale to the mixture's peak; leave silent estimates untouched
            # instead of producing NaNs from 0 / 0.
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
Ejemplo n.º 5
0
def run(mix_scp="data/wsj0_2mix/tt/mix.scp",
        aux_scp="data/wsj0_2mix/tt/aux.scp",
        checkpoint="exp_epoch114/conv_tasnet/conv-net",
        gpu=3,
        dump_dir="rec/",
        fs=8000):
    """Run target-speaker extraction and write one wav per mixture key.

    The defaults reproduce the values this function previously hard-coded, so
    ``run()`` behaves as before.

    Args:
        mix_scp: scp file listing the mixtures.
        aux_scp: scp file listing the auxiliary utterances (same keys).
        checkpoint: checkpoint directory passed to ``NnetComputer``.
        gpu: device id passed to ``NnetComputer``.
        dump_dir: output directory; files are named ``<key>.wav``.
        fs: sample rate used for reading and writing audio.
    """
    mix_input = WaveReader(mix_scp, sample_rate=fs)
    aux_input = WaveReader(aux_scp, sample_rate=fs)
    computer = NnetComputer(checkpoint, gpu)
    for key, mix_samps in mix_input:
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        spks = computer.compute(mix_samps, aux_samps, len(aux_samps))
        norm = np.linalg.norm(mix_samps, np.inf)
        for samps in spks:
            samps = samps[:mix_samps.size]
            # Rescale to the mixture's peak; skip silent estimates (0 / 0).
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak
            # NOTE(review): the file name does not depend on the output index,
            # so if ``spks`` holds more than one estimate only the last one
            # survives — confirm this is intended.
            write_wav(os.path.join(dump_dir, "{}.wav".format(key)),
                      samps,
                      fs=fs)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
Ejemplo n.º 6
0
def run2(args):
    """Separate "mixtures of mixtures" built on the fly from ``args.input``.

    For every utterance, a second utterance is drawn at random from the keys
    that do not match the speaker-exclusion regex built from the current key.
    The second utterance is trimmed or zero-padded to the same length and added
    to the first. The sum is separated; each estimate is rescaled to the sum's
    peak and written to ``<dump_dir>/spk<k>/<key>.wav``. Optional plots and
    per-output ``spk<k>.scp`` files are produced as well.

    Args:
        args: parsed command-line namespace; reads ``dump_dir``, ``input``,
            ``fs``, ``checkpoint``, ``gpu`` and ``plot``.

    Raises:
        FileExistsError: from ``os.mkdir`` if ``args.dump_dir`` already exists.
        ValueError: if no utterance is left to mix with the current one.
    """
    os.mkdir(args.dump_dir)
    mix_input = WaveReader(args.input, sample_rate=args.fs)
    computer = NnetComputer(args.checkpoint, args.gpu)
    cpyModelInfo(args.checkpoint, args.dump_dir)
    lenGen = 0
    for key, mix_samps in tqdm(mix_input):
        # NOTE(review): if ``logging`` is the stdlib module this comparison is
        # always False; presumably it is a module-level flag — confirm.
        if logging is True:
            logger.info("Compute on utterance {}...".format(key))

        # Speaker ids: first 3 chars of the 1st and 3rd "_"-separated fields
        # (assumes keys have at least three such fields — TODO confirm).
        # Escape them so they are matched literally inside the regex.
        firstSpeaker = re.escape(key[:3])
        secondSpeaker = re.escape(key.split('_')[2][:3])
        spk = "(({})|({}))".format(firstSpeaker, secondSpeaker)
        exclude = re.compile(
            "({0}?.*_.*_{0}.*)|({0}.*_.*_{0}?.*)".format(spk))

        aviableKeys = [k for k in mix_input.index_keys if not exclude.search(k)]
        if not aviableKeys:
            raise ValueError(
                "no utterance left to mix with {}".format(key))
        secondKey = random.choice(aviableKeys)
        secondMix = mix_input[secondKey]

        # Match the second mixture's length to the first one.
        if len(mix_samps) < len(secondMix):
            secondMix = secondMix[:len(mix_samps)]
        else:
            secondMix = np.pad(secondMix, (0, len(mix_samps) - len(secondMix)),
                               "constant", constant_values=(0, 0))

        mixofmix = mix_samps + secondMix

        spks = computer.compute(mixofmix)
        norm = np.linalg.norm(mixofmix, np.inf)
        lenGen = len(spks)
        for idx, samps in enumerate(spks):
            samps = samps[:mixofmix.size]
            # Rescale to the mixture's peak; skip silent estimates (0 / 0).
            peak = np.max(np.abs(samps))
            if peak > 0:
                samps = samps * norm / peak

            write_wav(
                os.path.join(args.dump_dir, "spk{}/{}.wav".format(
                    idx + 1, key)),
                samps,
                fs=args.fs)
        if args.plot != 0:
            plotOutputs(os.path.join(args.dump_dir, "plot_spk/{}.png".format(key)), spks)
    # Generate one SCP file per output channel.
    for idx in range(lenGen):
        generateFile(os.path.join(args.dump_dir, "spk{}".format(idx + 1)),
                     os.path.join(args.dump_dir, "spk{}.scp".format(idx + 1)))
    logger.info("Compute over {:d} utterances".format(len(mix_input)))
Ejemplo n.º 7
0
def run(args):
    """Build source/mixture wav pairs from a csv list of utterance triples.

    Each csv row is ``src_id, ref_id, itf_id``. The row's ids joined with "_"
    form the output name. That name and ``ref_id`` are appended
    (tab-separated) to ``<dump_dir>/emb.key``. ``mix_audio`` is applied to the
    source and interference waves, and its two results are written to
    ``<dump_dir>/src/<name>.wav`` and ``<dump_dir>/mix/<name>.wav``.

    Args:
        args: parsed command-line namespace; reads ``dump_dir``, ``wav_scp``
            and ``csv``.
    """
    if args.dump_dir:
        os.makedirs(args.dump_dir, exist_ok=True)
    wave_reader = WaveReader(args.wav_scp)
    emb_path = os.path.join(args.dump_dir, "emb.key")
    # The csv module requires files it reads to be opened with newline="".
    with open(emb_path, "w") as emb, open(args.csv, "r", newline="") as f:
        for ids in tqdm.tqdm(csv.reader(f)):
            src_id, ref_id, itf_id = ids
            utt_id = "_".join(ids)
            emb.write("{}\t{}\n".format(utt_id, ref_id))
            src, mix = mix_audio(wave_reader[src_id], wave_reader[itf_id])
            write_wav(
                os.path.join(args.dump_dir, "src/{}.wav".format(utt_id)), src)
            write_wav(
                os.path.join(args.dump_dir, "mix/{}.wav".format(utt_id)), mix)
Ejemplo n.º 8
0
def separating(args):
    """Separate every utterance of ``args.input`` with a OneFlow ConvTasNet.

    The network is built from the hyper-parameters in ``args``. Weights are
    loaded from ``args.model_path`` and the model runs on the "cuda" device.
    Each estimate is trimmed to the mixture length, rescaled to the mixture's
    peak and written to ``<dump_dir>/spk<k>/<key>``.

    Args:
        args: parsed command-line namespace; reads the ConvTasNet
            hyper-parameters (``L``, ``N``, ``X``, ``R``, ``B``, ``H``, ``P``,
            ``norm``, ``num_spks``, ``non_linear``, ``causal``) plus
            ``model_path``, ``input``, ``fs`` and ``dump_dir``.
    """
    # network configure
    nnet_conf = {
        "L": args.L,
        "N": args.N,
        "X": args.X,
        "R": args.R,
        "B": args.B,
        "H": args.H,
        "P": args.P,
        "norm": args.norm,
        "num_spks": args.num_spks,
        "non_linear": args.non_linear,
        "causal": args.causal,
    }
    nnet = ConvTasNet(**nnet_conf)
    nnet.load_state_dict(flow.load(args.model_path))
    device = flow.device("cuda")
    nnet.to(device)
    # Inference only: put any dropout / normalization layers in eval mode.
    nnet.eval()

    with flow.no_grad():
        mix_input = WaveReader(args.input, sample_rate=args.fs)
        for key, mix_samps in mix_input:
            raw = flow.tensor(mix_samps, dtype=flow.float32, device=device)
            sps = nnet(raw)
            spks = [np.squeeze(s.detach().cpu().numpy()) for s in sps]
            norm = np.linalg.norm(mix_samps, np.inf)
            for idx, samps in enumerate(spks):
                samps = samps[:mix_samps.size]
                # Rescale to the mixture's peak; skip silent estimates (0 / 0).
                peak = np.max(np.abs(samps))
                if peak > 0:
                    samps = samps * norm / peak
                write_wav(
                    os.path.join(args.dump_dir,
                                 "spk{}/{}".format(idx + 1, key)),
                    samps,
                    fs=args.fs,
                )
Ejemplo n.º 9
0
def run(args):
    """Enhance each wav in ``args.input`` and write speech/noise estimates.

    For every input file ``<dir>/<name>``:

    * output 0 is written to ``<dump_dir>/<name>.speech.wav``;
    * for a non-causal model, a least-squares rescaled copy of output 0 goes
      to the same path with ``<cpt_tag>`` replaced by
      ``<cpt_tag>_scale_optimized``;
    * in online mode, output 0 scaled by 0.89 (about -1 dB) also goes to
      ``<dump_dir>/<name>`` with ``<cpt_tag>_online`` replaced by
      ``<cpt_tag>_online/submit``;
    * output 1 is written to ``<dump_dir>/<name>.noise.wav``.

    Args:
        args: parsed command-line namespace; reads ``checkpoint``, ``online``
            (1 enables online mode), ``gpu``, ``input`` (list of wav paths),
            ``dump_dir`` and ``fs``.
    """
    cpt_tag = os.path.basename(args.checkpoint)
    online = args.online == 1

    if online:
        computer = NnetComputer(args.checkpoint, args.gpu, online=True)
    else:
        computer = NnetComputer(args.checkpoint, args.gpu)

    # (file name, samples); the file name also names the online submission
    # file, so every utterance gets its own output.
    mix_input = [(os.path.basename(f), read_wav(f)) for f in args.input]
    for key, mix_samps in mix_input:
        logger.info("Compute on utterance {}...".format(key))

        if online:
            # re-initialize
            computer.init_online(args.checkpoint, args.gpu)

        spks = computer.compute(mix_samps)

        for idx, samps in enumerate(spks):
            samps = samps[:mix_samps.size]
            # Outputs are written without peak normalisation. For causal models
            # we only report clipping (other gains tried previously:
            # 0.89 ~ -1 dB, 0.71 ~ -3 dB, 0.56).
            if computer.nnet.causal and np.max(np.abs(samps)) >= 1.0:
                logger.warning("clipping! ({})".format(key))

            if idx == 0:
                fname = os.path.join(args.dump_dir,
                                     "{}.speech.wav".format(key))
                write_wav(fname, samps, fs=args.fs)

                if not computer.nnet.causal:
                    # Least-squares gain <s, x> / <s, s> fitting the estimate
                    # to the mixture; a silent estimate is kept as is instead
                    # of becoming NaNs (0 / 0).
                    energy = np.dot(samps, samps)
                    if energy > 0:
                        ip = np.dot(samps, mix_samps)
                        samps_scale_optimized = ip / energy * samps
                    else:
                        samps_scale_optimized = samps
                    fname_scale_optimized = fname.replace(
                        cpt_tag, "%s_scale_optimized" % (cpt_tag))
                    write_wav(fname_scale_optimized,
                              samps_scale_optimized,
                              fs=args.fs)

                if online:
                    fname_submit = os.path.join(args.dump_dir, key)
                    fname_submit = fname_submit.replace(
                        "%s_online" % (cpt_tag),
                        "%s_online/submit" % (cpt_tag))
                    write_wav(fname_submit, samps * 0.89,
                              fs=args.fs)  # -1dB (to avoid clipping)

            elif idx == 1:
                fname = os.path.join(args.dump_dir, "{}.noise.wav".format(key))
                write_wav(fname, samps, fs=args.fs)

    logger.info("Compute over {:d} utterances".format(len(mix_input)))