Ejemplo n.º 1
 def _load_nnet(self, cpt_dir):
     """Build a ConvTasNet from a checkpoint directory and restore its weights.

     Args:
         cpt_dir: directory containing "mdl.json" (passed to ``load_json``;
             its contents are used as ConvTasNet keyword arguments) and
             "best.pt.tar" (the checkpoint read with ``th.load``).

     Returns:
         The ConvTasNet instance with the checkpoint parameters loaded.
     """
     nnet_conf = load_json(cpt_dir, "mdl.json")
     nnet = ConvTasNet(**nnet_conf)
     cpt_fname = os.path.join(cpt_dir, "best.pt.tar")
     # Deserialize onto the CPU regardless of where the checkpoint was saved.
     cpt = th.load(cpt_fname, map_location="cpu")
     # Apply the stored parameters: without this the checkpoint is read but
     # the returned network keeps its freshly-initialised weights.
     # NOTE(review): the "model_state_dict" key is assumed -- confirm it
     # matches the key used when the checkpoint was saved.
     nnet.load_state_dict(cpt["model_state_dict"])
     logger.info("Load checkpoint from {}, epoch {:d}".format(
         cpt_fname, cpt["epoch"]))
     return nnet
Ejemplo n.º 2
 def _load_nnet(self, cpt_dir, online=False, init_dump=False):
     """Build a TCNSENet from a checkpoint directory and restore its weights.

     Args:
         cpt_dir: directory containing "mdl.json" (passed to ``load_json``;
             its contents are used as TCNSENet keyword arguments) and
             "best.pt.tar" (the checkpoint read with ``th.load``).
         online: forwarded unchanged to the TCNSENet constructor.
         init_dump: forwarded unchanged to the TCNSENet constructor.

     Returns:
         The TCNSENet instance with the checkpoint parameters loaded.
     """
     nnet_conf = load_json(cpt_dir, "mdl.json")
     nnet = TCNSENet(**nnet_conf, online=online, init_dump=init_dump)
     cpt_fname = os.path.join(cpt_dir, "best.pt.tar")
     # Deserialize onto the CPU regardless of where the checkpoint was saved.
     cpt = th.load(cpt_fname, map_location="cpu")
     # Apply the stored parameters: without this the checkpoint is read but
     # the returned network keeps its freshly-initialised weights.
     # NOTE(review): the "model_state_dict" key is assumed -- confirm it
     # matches the key used when the checkpoint was saved.
     nnet.load_state_dict(cpt["model_state_dict"])
     logger.info("Load checkpoint from {}, epoch {:d}".format(
         cpt_fname, cpt["epoch"]))
     return nnet
Ejemplo n.º 3
def run(args):
    """Run the network over every dataset item and write each estimate to disk.

    Reads from ``args``: ``checkpoint`` (directory with "mdl.json" and
    "best.pt.tar"; "separate.log" is also written there), ``mix_scp``,
    ``ref_scp``, ``aux_scp`` (forwarded to ``Dataset``), ``gpuid`` (CUDA
    device index, or a negative value for CPU) and ``dump_dir`` (output
    directory for the written files).

    Each estimate is rescaled so that its peak magnitude equals the peak
    magnitude of the corresponding mixture before it is written.
    """
    start = time.time()
    logger = get_logger(
        os.path.join(args.checkpoint, 'separate.log'), file=True)
    dataset = Dataset(mix_scp=args.mix_scp,
                      ref_scp=args.ref_scp,
                      aux_scp=args.aux_scp)

    # Load model
    nnet_conf = load_json(args.checkpoint, "mdl.json")
    nnet = ConvTasNet(**nnet_conf)
    cpt_fname = os.path.join(args.checkpoint, "best.pt.tar")
    cpt = th.load(cpt_fname, map_location="cpu")
    # Apply the stored parameters; previously the checkpoint was read but the
    # network kept its freshly-initialised weights.
    # NOTE(review): the "model_state_dict" key is assumed -- confirm it
    # matches the key used when the checkpoint was saved.
    nnet.load_state_dict(cpt["model_state_dict"])
    logger.info("Load checkpoint from {}, epoch {:d}".format(
        cpt_fname, cpt["epoch"]))

    if args.gpuid >= 0:
        device = th.device("cuda:{}".format(args.gpuid))
    else:
        device = th.device("cpu")
    nnet = nnet.to(device)
    # Inference only: put any dropout / normalisation layers in eval mode.
    nnet.eval()

    total_cnt = 0
    with th.no_grad():
        for data in dataset:
            # Tensors are created directly on `device`. The former extra
            # `.cuda()` calls targeted the *default* CUDA device, which can
            # differ from cuda:{gpuid} and split inputs and model across GPUs.
            mix = th.tensor(data['mix'], dtype=th.float32, device=device)
            aux = th.tensor(data['aux'], dtype=th.float32, device=device)
            aux_len = th.tensor(
                data['aux_len'], dtype=th.float32, device=device)
            key = data['key']

            # Forward
            ests = nnet(mix, aux, aux_len)
            ests = ests.cpu().numpy()
            # Peak (infinity-norm) magnitude of the mixture.
            norm = np.linalg.norm(mix.cpu().numpy(), np.inf)
            # Trim the estimate to the mixture length.
            # NOTE(review): slicing the first axis presumes a 1-D estimate.
            ests = ests[:mix.shape[-1]]

            logger.info("Separate Utt{:d}".format(total_cnt + 1))

            # Rescale to the mixture peak; skip when the estimate is all
            # zeros to avoid a division by zero producing NaNs.
            peak = np.max(np.abs(ests))
            if peak > 0:
                ests = ests * norm / peak

            # NOTE(review): the original call was cut off after its first
            # argument. The estimate is passed as the data to write; restore
            # any further arguments (e.g. a sample rate) if write_wav needs
            # them.
            write_wav(os.path.join(args.dump_dir, key), ests)
            total_cnt += 1

    end = time.time()
    logger.info('Utt={:d} | Time Elapsed: {:.1f}s'.format(total_cnt, end-start))
Ejemplo n.º 4
def run(args):
    """Compute masks for every utterance and save them as per-speaker files.

    Reads from ``args``: ``checkpoint`` (given to ``NnetComputer`` and used
    to load "feats.json"), ``gpu`` (given to ``NnetComputer``), ``spectra``
    (given to ``Processor``), ``spatial`` (optional; when truthy a
    ``ScriptReader`` is opened on it) and ``dump_dir`` (output root).

    Each entry returned by ``computer.compute`` is saved with ``np.save``
    under ``<dump_dir>/spk<index>/<key>``, where the index starts at 1.
    """
    computer = NnetComputer(args.checkpoint, args.gpu)
    processed = 0
    spectra_reader = Processor(
        args.spectra, **load_json(args.checkpoint, "feats.json"))
    spatial_reader = None
    if args.spatial:
        spatial_reader = ScriptReader(args.spatial)

    out_root = Path(args.dump_dir)
    out_root.mkdir(exist_ok=True, parents=True)

    for key, feats in spectra_reader:
        logger.info("Compute on utterance {}...".format(key))
        if spatial_reader:
            # Append the spatial features column-wise to the spectral ones.
            feats = np.hstack([feats, spatial_reader[key]])
        for spk_idx, mask in enumerate(computer.compute(feats), start=1):
            spk_dir = out_root / f"spk{spk_idx:d}"
            spk_dir.mkdir(exist_ok=True)
            np.save(spk_dir / key, mask)
        processed += 1

    logger.info("Compute over {:d} utterances".format(processed))
Ejemplo n.º 5
def run(args):
    """Compute masks for every utterance and save them as per-speaker files.

    Reads from ``args``: ``checkpoint`` (given to ``NnetComputer`` and used
    to load "feats.json"), ``gpu`` (given to ``NnetComputer``), ``spectra``
    (given to ``Processor``), ``spatial`` (optional; when truthy a
    ``ScriptReader`` is opened on it) and ``dump_dir`` (output root).

    Each entry returned by ``computer.compute`` is saved with ``np.save``
    under ``<dump_dir>/spk<index>/<key>``, where the index starts at 1.
    """
    computer = NnetComputer(args.checkpoint, args.gpu)
    num_done = 0
    feats_conf = load_json(args.checkpoint, "feats.json")
    spectra = Processor(args.spectra, **feats_conf)
    spatial = ScriptReader(args.spatial) if args.spatial else None

    for key, feats in spectra:
        logger.info("Compute on utterance {}...".format(key))
        if spatial:
            spa = spatial[key]
            # Append the spatial features column-wise to the spectral ones.
            feats = np.hstack([feats, spa])
        spk_masks = computer.compute(feats)
        for i, m in enumerate(spk_masks):
            fdir = os.path.join(args.dump_dir, "spk{:d}".format(i + 1))
            # np.save does not create parent directories; make sure the
            # per-speaker directory exists so the first save cannot fail.
            os.makedirs(fdir, exist_ok=True)
            np.save(os.path.join(fdir, key), m)
        num_done += 1
    logger.info("Compute over {:d} utterances".format(num_done))
Ejemplo n.º 6
def evaluate(args):
    """Evaluate the network on a dataset and log per-utterance and mean scores.

    Reads from ``args``: ``checkpoint`` (directory with "mdl.json" and
    "best.pt.tar"; "eval.log" is also written there), ``mix_scp`` (and the
    other script paths forwarded to ``Dataset``), ``gpuid`` (CUDA device
    index, or a negative value for CPU) and ``cal_sdr`` (when truthy, SDR is
    computed with ``bss_eval_sources`` in addition to SI-SNR).
    """
    start = time.time()
    total_SISNR = 0
    total_SDR = 0
    total_cnt = 0

    # build the logger object
    logger = get_logger(os.path.join(args.checkpoint, 'eval.log'), file=True)

    # Load model
    nnet_conf = load_json(args.checkpoint, "mdl.json")
    nnet = ConvTasNet(**nnet_conf)
    cpt_fname = os.path.join(args.checkpoint, "best.pt.tar")
    cpt = th.load(cpt_fname, map_location="cpu")
    # Apply the stored parameters; previously the checkpoint was read but the
    # network kept its freshly-initialised weights.
    # NOTE(review): the "model_state_dict" key is assumed -- confirm it
    # matches the key used when the checkpoint was saved.
    nnet.load_state_dict(cpt["model_state_dict"])
    logger.info("Load checkpoint from {}, epoch {:d}".format(
        cpt_fname, cpt["epoch"]))

    if args.gpuid >= 0:
        device = th.device("cuda:{}".format(args.gpuid))
    else:
        device = th.device("cpu")
    nnet = nnet.to(device)
    # Inference only: put any dropout / normalisation layers in eval mode.
    nnet.eval()

    # Load data
    # NOTE(review): the original call was cut off after `mix_scp`; the
    # `ref_scp` / `aux_scp` arguments are an assumption -- confirm against
    # the Dataset signature used by this script.
    dataset = Dataset(mix_scp=args.mix_scp,
                      ref_scp=args.ref_scp,
                      aux_scp=args.aux_scp)

    with th.no_grad():
        for data in dataset:
            # Tensors are created directly on `device`. The former extra
            # `.cuda()` calls targeted the *default* CUDA device, which can
            # differ from cuda:{gpuid} and split inputs and model across GPUs.
            mix1 = th.tensor(data['mix1'], dtype=th.float32, device=device)
            mix2 = th.tensor(data['mix2'], dtype=th.float32, device=device)
            aux = th.tensor(data['aux'], dtype=th.float32, device=device)
            # NOTE(review): this call was cut off in the original; dtype and
            # device are assumed to match the other input tensors.
            aux_len = th.tensor(
                data['aux_len'], dtype=th.float32, device=device)

            # Forward
            ref = data['ref']
            ests, _ = nnet(mix1, mix2, aux, aux_len)
            ests = ests.cpu().numpy()
            if ests.size != ref.size:
                # Trim both signals to their common length before scoring.
                common_len = min(ests.size, ref.size)
                ests = ests[:common_len]
                ref = ref[:common_len]

            # for each utts
            # Compute SDR (only when requested)
            if args.cal_sdr:
                SDR, _, _, _ = bss_eval_sources(ref, ests)
                total_SDR += SDR[0]
            # Compute SI-SNR
            SISNR = cal_SISNR(ests, ref)
            if args.cal_sdr:
                logger.info("Utt={:d} | SDR={:.2f} | SI-SNR={:.2f}".format(
                    total_cnt + 1, SDR[0], SISNR))
            else:
                # Without this branch the SI-SNR-only line was logged a
                # second time when SDR was enabled and never otherwise.
                logger.info("Utt={:d} | SI-SNR={:.2f}".format(
                    total_cnt + 1, SISNR))
            total_SISNR += SISNR
            total_cnt += 1
    end = time.time()

    logger.info('Time Elapsed: {:.1f}s'.format(end - start))
    if total_cnt == 0:
        # Nothing was evaluated: skip the averages instead of dividing by 0.
        logger.info("No utterances evaluated")
        return
    if args.cal_sdr:
        logger.info("Average SDR: {0:.2f}".format(total_SDR / total_cnt))
    logger.info("Average SI-SNR: {:.2f}".format(total_SISNR / total_cnt))