Example no. 1
0
def main():
    """Run live wake-word inference from the command line.

    Loads a trained model plus its ZMUV normalization statistics from a
    workspace, builds a frame- or sequence-based inference engine
    depending on the training objective, and blocks on the inference
    client until it terminates.
    """
    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')))
    args = apb.parser.parse_args()
    # 'frame' objective -> per-window classification; otherwise sequence
    # decoding, which needs a blank label in the vocabulary (use_blank).
    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)

    # Reuse the existing workspace; do not wipe previously saved artifacts.
    ws = Workspace(Path(args.workspace), delete_existing=False)

    device = torch.device(SETTINGS.training.device)
    zmuv_transform = ZmuvTransform().to(device)
    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).eval()
    # Restore the zero-mean unit-variance stats captured during training.
    zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin'), map_location=device))

    # Load the best checkpoint and switch the model to streaming mode
    # before constructing the engine.
    ws.load_model(model, best=True)
    model.streaming()
    if use_frame:
        # Window and stride are configured in seconds; engine takes ms.
        engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                      int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                      SETTINGS.audio.sample_rate,
                                      model,
                                      zmuv_transform,
                                      negative_label=ctx.negative_label,
                                      coloring=ctx.coloring)
    else:
        engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                         model,
                                         zmuv_transform,
                                         negative_label=ctx.negative_label,
                                         coloring=ctx.coloring)
    client = InferenceClient(engine, device, SETTINGS.training.vocab)
    client.join()
Example no. 2
0
def _load_model(pretrained: bool, model_name: str, workspace_path: str,
                device: str,
                **kwargs) -> typing.Tuple[InferenceEngine, InferenceContext]:
    """
    Load a howl model from a workspace and build its inference engine.

    Arguments:
        pretrained (bool): restore pretrained model weights when True
        model_name (str): registered name of the model architecture
        workspace_path (str): path to the workspace, relative to the
            root of the downloaded howl-models cache
        device (str): torch device string used to map loaded weights
        **kwargs: may carry ``reload_models`` to force re-downloading
            the model cache

    Returns the inference engine and context
    """

    # Pop our own flag first; PyTorch hub consumes 'force_reload' separately.
    force_refresh = kwargs.pop("reload_models", False)
    models_root = _download_howl_models(force_refresh)
    workspace_path = pathlib.Path(models_root) / workspace_path
    workspace = howl_model.Workspace(workspace_path, delete_existing=False)

    # Settings saved at training time drive engine construction below.
    settings = workspace.load_settings()

    # 'frame' objective -> per-window classification; anything else is
    # sequence decoding, which requires a blank token in the vocabulary.
    frame_objective = settings.training.objective == "frame"
    context = InferenceContext(settings.training.vocab,
                               token_type=settings.training.token_type,
                               use_blank=not frame_objective)

    # Instantiate the normalization transform and the network in eval mode.
    zmuv = transform.ZmuvTransform()
    model_cls = howl_model.RegisteredModel.find_registered_class(model_name)
    model = model_cls(context.num_labels).eval()

    if pretrained:
        map_to = torch.device(device)
        zmuv.load_state_dict(
            torch.load(str(workspace.path / "zmuv.pt.bin"), map_location=map_to))
        workspace.load_model(model, best=True)

    # Switch to streaming mode before the engine wraps the model.
    model.streaming()
    if frame_objective:
        # Window/stride settings are in seconds; the engine expects ms.
        window_ms = int(settings.training.max_window_size_seconds * 1000)
        stride_ms = int(settings.training.eval_stride_size_seconds * 1000)
        engine = FrameInferenceEngine(window_ms, stride_ms, model, zmuv, context)
    else:
        engine = SequenceInferenceEngine(model, zmuv, context)

    return engine, context
Example no. 3
0
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        """Evaluate the current model over *dataset* and log a confusion matrix.

        Args:
            dataset: examples to run inference on.
            prefix: label for progress bar, error log, and metric names.
            save: when True (and not in eval mode), record the TP count and
                checkpoint the model via ``ws.increment_model``.
            positive_set: expected detection outcome for every example; an
                engine result differing from this counts as an error.
            write_errors: append misclassified examples to ``errors.tsv``.
            mixer: optional augmentation applied per example before inference.

        NOTE(review): relies on enclosing-scope names (use_frame, model,
        zmuv_transform, ctx, std_transform, ws, device, args, writer,
        epoch_idx) — this is a nested function; verify against the caller.
        """
        std_transform.eval()

        # Build a fresh engine per evaluation so thresholds/state are current.
        if use_frame:
            engine = FrameInferenceEngine(
                int(SETTINGS.training.max_window_size_seconds * 1000),
                int(SETTINGS.training.eval_stride_size_seconds * 1000),
                SETTINGS.audio.sample_rate,
                model,
                zmuv_transform,
                negative_label=ctx.negative_label,
                coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            # Section header in the shared error log (append mode).
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                # Mixer takes/returns a batch; unpack the single example.
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            # Reset engine state so utterances don't leak across examples.
            engine.reset()
            seq_present = engine.infer(audio_data)
            # A mismatch with the expected outcome is a classification error.
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(
                        f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n'
                    )
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}',
                                  c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            # Track true positives and checkpoint if this run improved.
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            # Persist per-threshold results for offline analysis.
            threshold = engine.threshold
            with (ws.path /
                  (str(round(threshold, 2)) + '_results.csv')).open('a') as f:
                f.write(
                    f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n'
                )
Example no. 4
0
def main():
    """Run streaming wake-word detection via HowlClient.

    Restores the model and ZMUV statistics from the given workspace,
    builds the appropriate inference engine, and blocks on the client.
    """
    arg_builder = ArgumentParserBuilder()
    arg_builder.add_options(
        opt("--model",
            type=str,
            choices=RegisteredModel.registered_names(),
            default="las"),
        opt("--workspace",
            type=str,
            default=str(Path("workspaces") / "default")),
    )
    args = arg_builder.parser.parse_args()
    workspace = Workspace(Path(args.workspace), delete_existing=False)
    settings = workspace.load_settings()

    # 'frame' objective -> per-window classification; otherwise sequence
    # decoding, which needs a blank label in the vocabulary.
    frame_based = settings.training.objective == "frame"
    context = InferenceContext(settings.training.vocab,
                               token_type=settings.training.token_type,
                               use_blank=not frame_based)

    device = torch.device(settings.training.device)
    zmuv = ZmuvTransform().to(device)
    model_cls = RegisteredModel.find_registered_class(args.model)
    model = model_cls(context.num_labels).to(device).eval()
    # Restore normalization stats captured during training.
    zmuv.load_state_dict(
        torch.load(str(workspace.path / "zmuv.pt.bin"), map_location=device))

    workspace.load_model(model, best=True)
    model.streaming()
    if frame_based:
        # Window/stride settings are in seconds; the engine expects ms.
        window_ms = int(settings.training.max_window_size_seconds * 1000)
        stride_ms = int(settings.training.eval_stride_size_seconds * 1000)
        engine = FrameInferenceEngine(window_ms, stride_ms, model, zmuv, context)
    else:
        engine = SequenceInferenceEngine(model, zmuv, context)

    client = HowlClient(engine, context)
    client.start().join()