コード例 #1
0
def run(args):
    """Estimate CGMM TF-masks for every utterance listed in ``args.wav_scp``.

    For each utterance whose ``<dst_dir>/<key>.npy`` does not exist yet, a
    ``CgmmTrainer`` is fitted on the utterance STFT and the resulting masks
    are written through ``NumpyWriter``. Utterances with an existing output
    file are skipped, so an interrupted run can be resumed.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``frame_len``, ``frame_hop``, ``round_power_of_two``, ``window``,
            ``center`` (STFT options); ``seed``; ``wav_scp``; ``fmt``
            (``"numpy"`` or ``"kaldi"``, selects the init-mask reader);
            ``init_mask`` (optional init-mask source); ``dst_dir``;
            ``num_classes``; ``update_alpha``; ``num_iters``; ``solve_permu``.

    Raises:
        KeyError: if ``args.fmt`` is not ``"numpy"`` or ``"kaldi"`` while
            ``args.init_mask`` is set.
    """
    stft_kwargs = {
        "frame_len": args.frame_len,
        "frame_hop": args.frame_hop,
        "round_power_of_two": args.round_power_of_two,
        "window": args.window,
        "center": args.center,
        "transpose": False
    }
    np.random.seed(args.seed)
    spectrogram_reader = SpectrogramReader(args.wav_scp, **stft_kwargs)
    mask_reader_cls = {"numpy": NumpyReader, "kaldi": ScriptReader}
    init_mask_reader = (mask_reader_cls[args.fmt](args.init_mask)
                        if args.init_mask else None)

    num_done = 0
    with NumpyWriter(args.dst_dir) as writer:
        dst_dir = Path(args.dst_dir)
        for key, stft in spectrogram_reader:
            # Resume support: an existing output file means "already done".
            if (dst_dir / f"{key}.npy").exists():
                logger.info(f"Training utterance {key} ... Skip")
                continue

            init_mask = None
            if init_mask_reader and key in init_mask_reader:
                init_mask = init_mask_reader[key]
                # Stored masks are time-major; swap the last two axes:
                # T x F => F x T (presumably K x T x F => K x F x T for 3-D)
                if init_mask.ndim == 2:
                    init_mask = np.transpose(init_mask)
                else:
                    init_mask = np.transpose(init_mask, (0, 2, 1))
                logger.info("Using external TF-mask to initialize cgmm")

            # stft: N x F x T
            trainer = CgmmTrainer(stft,
                                  args.num_classes,
                                  gamma=init_mask,
                                  update_alpha=args.update_alpha)
            try:
                masks = trainer.train(args.num_iters)
                # K x F x T => K x T x F
                masks = np.transpose(masks, (0, 2, 1))
                if args.solve_permu:
                    masks = permu_aligner(masks)
                    logger.info(
                        "Permutation alignment done on each frequency")
                if args.num_classes == 2:
                    # two-class case: only the first class mask is stored
                    masks = masks[0]
                writer.write(key, masks.astype(np.float32))
                # count only after the mask has actually been written
                num_done += 1
                logger.info(f"Training utterance {key} ... Done")
            except RuntimeError:
                logger.warning(f"Training utterance {key} ... Failed")
    logger.info(
        f"Train {num_done:d} utterances over {len(spectrogram_reader):d}")
コード例 #2
0
def run(args):
    """Estimate CACGMM TF-masks for every utterance listed in ``args.wav_scp``.

    For each utterance whose ``<dst_dir>/<key>.npy`` does not exist yet, a
    ``CacgmmTrainer`` is fitted on the utterance STFT and the resulting
    masks are written through ``NumpyWriter`` as K x T x F. Utterances with
    an existing output file are skipped, so an interrupted run can be resumed.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``frame_len``, ``frame_hop``, ``round_power_of_two``, ``window``,
            ``center`` (STFT options); ``seed``; ``wav_scp``; ``fmt``
            (``"numpy"`` or ``"kaldi"``, selects the init-mask reader);
            ``init_mask`` (optional init-mask source); ``dst_dir``;
            ``num_classes``; ``cgmm_init``; ``num_epoches``.

    Raises:
        KeyError: if ``args.fmt`` is not ``"numpy"`` or ``"kaldi"`` while
            ``args.init_mask`` is set.
    """
    stft_kwargs = {
        "frame_len": args.frame_len,
        "frame_hop": args.frame_hop,
        "round_power_of_two": args.round_power_of_two,
        "window": args.window,
        "center": args.center,
        "transpose": False
    }
    np.random.seed(args.seed)
    spectrogram_reader = SpectrogramReader(args.wav_scp, **stft_kwargs)
    mask_reader_cls = {"numpy": NumpyReader, "kaldi": ScriptReader}
    init_mask_reader = (mask_reader_cls[args.fmt](args.init_mask)
                        if args.init_mask else None)

    num_done = 0
    with NumpyWriter(args.dst_dir) as writer:
        dst_dir = Path(args.dst_dir)
        for key, stft in spectrogram_reader:
            # Resume support: an existing output file means "already done".
            if (dst_dir / f"{key}.npy").exists():
                logger.info(f"Training utterance {key} ... Skip")
                continue

            # External init mask is passed through unchanged; presumably
            # K x F x T is expected here — TODO confirm against CacgmmTrainer
            init_mask = None
            if init_mask_reader and key in init_mask_reader:
                init_mask = init_mask_reader[key]
                logger.info("Using external mask to initialize cacgmm")

            # stft: N x F x T
            trainer = CacgmmTrainer(stft,
                                    args.num_classes,
                                    gamma=init_mask,
                                    cgmm_init=args.cgmm_init)
            try:
                # EM iterations; trainer output is K x F x T
                masks = trainer.train(args.num_epoches)
                if not args.cgmm_init or args.num_classes != 2:
                    # align permutations and transpose K x F x T => K x T x F
                    masks = permu_aligner(masks, transpose=True)
                    logger.info(
                        "Permutation alignment done on each frequency")
                else:
                    # No alignment needed, but still emit K x T x F so the
                    # on-disk layout does not depend on the chosen options.
                    masks = np.transpose(masks, (0, 2, 1))
                writer.write(key, masks.astype("float32"))
                # count only after the mask has actually been written
                num_done += 1
                logger.info(f"Training utterance {key} ... Done")
            except np.linalg.LinAlgError:
                logger.warning(f"Training utterance {key} ... Failed")
    logger.info(
        f"Train {num_done:d} utterances over {len(spectrogram_reader):d}")