コード例 #1
0
def main():
    def copy_files(dataset, output_dir, deep_copy=False):
        print('copying files to', output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        for item in tqdm(dataset):
            output_path = output_dir / item.metadata.path.name
            try:
                if deep_copy:
                    copyfile(item.metadata.path, output_path)
                else:
                    os.symlink(item.metadata.path, output_path)
            except FileExistsError:
                pass

    def write_files(dataset, output_dir, mixer: DatasetMixer):
        print('copying files to', output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        for item in tqdm(dataset):
            item, = mixer([item])
            output_path = output_dir / item.metadata.path.name
            librosa.output.write_wav(output_path, item.audio_data.numpy(),
                                     16000)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--dataset-paths',
            '-i',
            type=str,
            nargs='+',
            default=[SETTINGS.dataset.dataset_path]),
        opt('--output-paths', '-o', type=str, default='data/precise'),
        opt('--deep-copy', action='store_true'))
    args = apb.parser.parse_args()

    if args.deep_copy:
        print('copying the audio files')
    else:
        print('generating symlink files')

    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab,
                           token_type=SETTINGS.training.token_type,
                           use_blank=not use_frame)

    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate,
                     mono=SETTINGS.audio.use_mono,
                     frame_labeler=ctx.labeler)

    inference_settings = SETTINGS.inference_engine
    wakeword = '_'.join(
        np.array(SETTINGS.training.vocab)[
            inference_settings.inference_sequence]).strip()

    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)

    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)

    print_stats(f'Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    output_path = Path(args.output_paths) / wakeword

    ww_train_pos_ds = ww_train_ds.filter(
        lambda x: ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'train positive dataset', ww_train_pos_ds)
    copy_files(ww_train_pos_ds, output_path / 'wake-word', args.deep_copy)

    ww_train_neg_ds = ww_train_ds.filter(
        lambda x: not ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'train negative dataset', ww_train_neg_ds)
    copy_files(ww_train_neg_ds, output_path / 'not-wake-word', args.deep_copy)

    noise_ds = RecursiveNoiseDatasetLoader().load(
        Path(SETTINGS.raw_dataset.noise_dataset_path),
        sr=SETTINGS.audio.sample_rate,
        mono=SETTINGS.audio.use_mono)
    noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
    noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))

    dev_mixer = DatasetMixer(noise_ds_dev, seed=10, do_replace=False)
    test_mixer = DatasetMixer(noise_ds_test, seed=10, do_replace=False)

    ww_dev_pos_ds = ww_dev_ds.filter(
        lambda x: ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'dev positive dataset', ww_dev_pos_ds)
    copy_files(ww_dev_pos_ds, output_path / 'dev/wake-word', args.deep_copy)
    write_files(ww_dev_pos_ds, output_path / 'noisy-dev/wake-word', dev_mixer)

    ww_dev_neg_ds = ww_dev_ds.filter(
        lambda x: not ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'dev negative dataset', ww_dev_neg_ds)
    copy_files(ww_dev_neg_ds, output_path / 'dev/not-wake-word',
               args.deep_copy)
    write_files(ww_dev_neg_ds, output_path / 'noisy-dev/not-wake-word',
                dev_mixer)

    ww_test_pos_ds = ww_test_ds.filter(
        lambda x: ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'test positive dataset', ww_test_pos_ds)
    copy_files(ww_test_pos_ds, output_path / 'test/wake-word', args.deep_copy)
    write_files(ww_test_pos_ds, output_path / 'noisy-test/wake-word',
                test_mixer)

    ww_test_neg_ds = ww_test_ds.filter(
        lambda x: not ctx.searcher.search(x.transcription), clone=True)
    print_stats(f'test negative dataset', ww_test_neg_ds)
    copy_files(ww_test_neg_ds, output_path / 'test/not-wake-word',
               args.deep_copy)
    write_files(ww_test_neg_ds, output_path / 'noisy-test/not-wake-word',
                test_mixer)
コード例 #2
0
ファイル: train.py プロジェクト: wangwei2009/howl
def main():
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        std_transform.eval()

        if use_frame:
            engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                          int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                          SETTINGS.audio.sample_rate,
                                          model,
                                          zmuv_transform,
                                          negative_label=ctx.negative_label,
                                          coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring,
                                             blank_idx=ctx.blank_label)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()
            seq_present = engine.infer(audio_data)
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            threshold = engine.threshold
            with (ws.path / (str(round(threshold, 2)) + '_results.csv') ).open('a') as f:
                f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')

    def do_evaluate():
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_dev_pos_ds, 'Dev noisy positive', positive_set=True, mixer=dev_mixer)
            evaluate_engine(ww_dev_neg_ds, 'Dev noisy negative', positive_set=False, mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_test_pos_ds, 'Test noisy positive', positive_set=True, mixer=test_mixer)
            evaluate_engine(ww_test_neg_ds, 'Test noisy negative', positive_set=False, mixer=test_mixer)

    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
                    opt('--load-weights', action='store_true'),
                    opt('--load-last', action='store_true'),
                    opt('--no-dev-per-epoch', action='store_false', dest='dev_per_epoch'),
                    opt('--dataset-paths', '-i', type=str, nargs='+', default=[SETTINGS.dataset.dataset_path]),
                    opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)
    if use_frame:
        batchifier = WakeWordFrameBatchifier(ctx.negative_label,
                                             window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
        criterion = nn.CrossEntropyLoss()
    else:
        tokenizer = WakeWordTokenizer(ctx.vocab, ignore_oov=False)
        batchifier = AudioSequenceBatchifier(ctx.negative_label, tokenizer)
        criterion = nn.CTCLoss(ctx.blank_label)

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=ctx.labeler)

    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)
    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)
    print_stats(f'Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)

    print_stats(f'Dev dataset', ww_dev_pos_ds, ww_dev_neg_ds)
    print_stats(f'Test dataset', ww_test_pos_ds, ww_test_neg_ds)
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)

    train_comp = (NoiseTransform().train(), batchifier)

    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(Path(SETTINGS.raw_dataset.noise_dataset_path),
                                                      sr=SETTINGS.audio.sample_rate,
                                                      mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
        train_comp = (DatasetMixer(noise_ds_train).train(),) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)

    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).streaming()
    if SETTINGS.training.convert_static:
        model = ConvertedStaticModel(model, 40, 10)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate, weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    ws.write_args(args)
    ws.write_settings(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        model.streaming_state = None
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        total_loss = torch.Tensor([0.0]).to(device)
        for batch in pbar:
            batch.to(device)
            if use_frame:
                scores = model(zmuv_transform(std_transform(batch.audio_data)),
                               std_transform.compute_lengths(batch.lengths))
                loss = criterion(scores, batch.labels)
            else:
                lengths = std_transform.compute_lengths(batch.audio_lengths)
                scores = model(zmuv_transform(std_transform(batch.audio_data)), lengths)
                scores = F.log_softmax(scores, -1)  # [num_frames x batch_size x num_labels]
                lengths = torch.tensor([model.compute_length(x.item()) for x in lengths]).to(device)
                loss = criterion(scores, batch.labels, lengths, batch.label_lengths)
            optimizer.zero_grad()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            with torch.no_grad():
                total_loss += loss

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay

        mean = total_loss / len(train_dl)
        writer.add_scalar('Training/Loss', mean.item(), epoch_idx)
        writer.add_scalar('Training/LearningRate', group['lr'], epoch_idx)

        if args.dev_per_epoch:
            evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True, save=True, write_errors=False)

    do_evaluate()