Exemplo n.º 1
0
def run(args):
    """Run speaker separation over every utterance in ``args.wave_scp``.

    For each mixture utterance, computes its STFT, runs the PIT separator
    and writes one wave file per estimated speaker into ``args.dump_dir``
    (optionally dumping the estimated masks as ``.mat`` files).

    Args:
        args: parsed CLI namespace with attributes ``config``,
            ``state_dict``, ``cuda``, ``wave_scp``, ``dump_dir`` and
            ``dump_mask``.

    Raises:
        FileNotFoundError: if the configured cmvn dict path does not exist.
    """
    num_bins, config_dict = parse_yaml(args.config)
    dataloader_conf = config_dict["dataloader"]
    spectrogram_conf = config_dict["spectrogram_reader"]
    # Load cmvn statistics (a pickled dict) if a path is configured;
    # a falsy value (empty string / None) disables mean-variance norm.
    dict_mvn = dataloader_conf["mvn_dict"]
    if dict_mvn:
        if not os.path.exists(dict_mvn):
            raise FileNotFoundError("Could not find mvn files")
        with open(dict_mvn, "rb") as f:
            dict_mvn = pickle.load(f)
    # default: True
    apply_log = dataloader_conf.get("apply_log", True)

    dcnet = PITNet(num_bins, **config_dict["model"])

    frame_length = spectrogram_conf["frame_length"]
    frame_shift = spectrogram_conf["frame_shift"]
    window = spectrogram_conf["window"]

    separator = Separator(dcnet, args.state_dict, cuda=args.cuda)

    utt_dict = parse_scps(args.wave_scp)
    num_utts = 0
    for key, utt in utt_dict.items():
        try:
            samps, stft_mat = stft(utt,
                                   frame_length=frame_length,
                                   frame_shift=frame_shift,
                                   window=window,
                                   center=True,
                                   return_samps=True)
        except FileNotFoundError:
            print("Skip utterance {}... not found".format(key))
            continue
        print("Processing utterance {}".format(key))
        num_utts += 1
        # Peak amplitude of the mixture; used by istft to renormalize output.
        norm = np.linalg.norm(samps, np.inf)
        # NOTE(review): "seperate" is the Separator API's (misspelled) method
        # name — do not "fix" it here without renaming it at the definition.
        spk_mask, spk_spectrogram = separator.seperate(stft_mat,
                                                       cmvn=dict_mvn,
                                                       apply_log=apply_log)

        # Use a distinct loop name so the mixture's stft_mat is not shadowed.
        for index, spk_mat in enumerate(spk_spectrogram):
            istft(os.path.join(args.dump_dir,
                               '{}.spk{}.wav'.format(key, index + 1)),
                  spk_mat,
                  frame_length=frame_length,
                  frame_shift=frame_shift,
                  window=window,
                  center=True,
                  norm=norm,
                  fs=8000,
                  nsamps=samps.size)
            if args.dump_mask:
                sio.savemat(
                    os.path.join(args.dump_dir,
                                 '{}.spk{}.mat'.format(key, index + 1)),
                    {"mask": spk_mask[index]})
    print("Processed {} utterance!".format(num_utts))
Exemplo n.º 2
0
def train(args):
    """Train a PIT separation network according to ``args.config``.

    Args:
        args: parsed CLI namespace with attributes ``gpus`` (comma-separated
            device ids), ``debug``, ``config`` and ``num_epoches``.
    """
    gpuid = tuple(int(dev) for dev in args.gpus.split(','))
    debug = args.debug
    mode = 'debug' if debug else 'normal'
    logger.info("Start training in {} model".format(mode))

    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]

    target = "IRM" if reader_conf["apply_abs"] else "PSM"
    logger.info("Training with {}".format(target))

    batch_size = loader_conf["batch_size"]
    if batch_size == 1:
        batch_desc = "per utterance"
    else:
        batch_desc = '{} utterance per batch'.format(batch_size)
    logger.info("Training in {}".format(batch_desc))

    # In debug mode both loaders read from the (small) debug scp set.
    train_scp = config_dict["debug_scp_conf" if debug else "train_scp_conf"]
    valid_scp = config_dict["debug_scp_conf" if debug else "valid_scp_conf"]
    train_loader = uttloader(train_scp, reader_conf, loader_conf, train=True)
    valid_loader = uttloader(valid_scp, reader_conf, loader_conf, train=False)

    checkpoint = config_dict["trainer"]["checkpoint"]
    ckpt_desc = "default checkpoint" if checkpoint is None else checkpoint
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, ckpt_desc))

    nnet = PITNet(num_bins, **dcnnet_conf)
    trainer = PITrainer(nnet, **config_dict["trainer"], gpuid=gpuid)
    trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
Exemplo n.º 3
0
def train(args):
    """Train a PIT separation network, optionally resuming from a state dict.

    Args:
        args: parsed CLI namespace with attributes ``debug``, ``config``,
            ``state_dict`` (path to a saved model, "" to start fresh),
            ``cpu``, ``num_epoches`` and ``start`` (starting epoch).

    Raises:
        ValueError: if ``args.state_dict`` is non-empty but does not exist.
    """
    debug = args.debug
    logger.info(
        "Start training in {} model".format('debug' if debug else 'normal'))
    num_bins, config_dict = parse_yaml(args.config)
    reader_conf = config_dict["spectrogram_reader"]
    loader_conf = config_dict["dataloader"]
    dcnnet_conf = config_dict["model"]
    state_dict = args.state_dict

    # None lets torch.load keep the tensors on their saved device.
    location = "cpu" if args.cpu else None

    logger.info("Training with {}".format("IRM" if reader_conf["apply_abs"]
                                          else "PSM"))
    batch_size = loader_conf["batch_size"]
    logger.info(
        "Training in {}".format("per utterance" if batch_size == 1 else
                                '{} utterance per batch'.format(batch_size)))

    # In debug mode both loaders read from the (small) debug scp set.
    train_loader = uttloader(
        config_dict["train_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=True)
    valid_loader = uttloader(
        config_dict["valid_scp_conf"]
        if not debug else config_dict["debug_scp_conf"],
        reader_conf,
        loader_conf,
        train=False)
    checkpoint = config_dict["trainer"]["checkpoint"]
    logger.info("Training for {} epoches -> {}...".format(
        args.num_epoches, "default checkpoint"
        if checkpoint is None else checkpoint))

    nnet = PITNet(num_bins, **dcnnet_conf)

    # Resume from a previous state dict when a non-empty path is given.
    if state_dict != "":
        if not os.path.exists(state_dict):
            raise ValueError("there is no path {}".format(state_dict))
        logger.info("load {}".format(state_dict))
        nnet.load_state_dict(th.load(state_dict, map_location=location))

    trainer = PITrainer(nnet, **config_dict["trainer"])
    trainer.run(train_loader, valid_loader,
                num_epoches=args.num_epoches, start=args.start)