def main():
    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')))
    args = apb.parser.parse_args()

    # Build the inference context from the global training settings
    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab,
                           token_type=SETTINGS.training.token_type,
                           use_blank=not use_frame)

    ws = Workspace(Path(args.workspace), delete_existing=False)
    device = torch.device(SETTINGS.training.device)

    # Load the ZMUV transform and the best checkpoint, then switch the model to streaming mode
    zmuv_transform = ZmuvTransform().to(device)
    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).eval()
    zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin'), map_location=device))
    ws.load_model(model, best=True)
    model.streaming()

    # Frame-level objectives use a sliding-window engine; otherwise, a sequence engine
    if use_frame:
        engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                      int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                      SETTINGS.audio.sample_rate,
                                      model,
                                      zmuv_transform,
                                      negative_label=ctx.negative_label,
                                      coloring=ctx.coloring)
    else:
        engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                         model,
                                         zmuv_transform,
                                         negative_label=ctx.negative_label,
                                         coloring=ctx.coloring)

    client = InferenceClient(engine, device, SETTINGS.training.vocab)
    client.join()
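# For context: `zmuv.pt.bin` above holds the statistics for zero-mean,
# unit-variance (ZMUV) input normalization. A minimal standalone sketch of the
# idea, assuming `mean` and `std` tensors fit on the training audio; this is an
# illustration, not howl's actual ZmuvTransform implementation:
import torch

def zmuv_normalize(audio: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Normalize audio to zero mean and unit variance using precomputed stats."""
    return (audio - mean) / (std + 1e-7)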
def _load_model(pretrained: bool,
                model_name: str,
                workspace_path: str,
                device: str,
                **kwargs) -> typing.Tuple[InferenceEngine, InferenceContext]:
    """Load a howl model from a workspace.

    Arguments:
        pretrained (bool): load pretrained model weights
        model_name (str): name of the model to use
        workspace_path (str): relative path to the workspace from the root of howl-models

    Returns:
        The inference engine and context.
    """
    # Separate `reload_models` flag since PyTorch will pop the 'force_reload' flag
    reload_models = kwargs.pop("reload_models", False)
    cached_folder = _download_howl_models(reload_models)
    workspace_path = pathlib.Path(cached_folder) / workspace_path
    ws = howl_model.Workspace(workspace_path, delete_existing=False)

    # Load model settings
    settings = ws.load_settings()

    # Set up context
    use_frame = settings.training.objective == "frame"
    ctx = InferenceContext(settings.training.vocab,
                           token_type=settings.training.token_type,
                           use_blank=not use_frame)

    # Load models
    zmuv_transform = transform.ZmuvTransform()
    model = howl_model.RegisteredModel.find_registered_class(model_name)(ctx.num_labels).eval()

    # Load pretrained weights
    if pretrained:
        zmuv_transform.load_state_dict(
            torch.load(str(ws.path / "zmuv.pt.bin"), map_location=torch.device(device)))
        ws.load_model(model, best=True)

    # Load engine
    model.streaming()
    if use_frame:
        engine = FrameInferenceEngine(
            int(settings.training.max_window_size_seconds * 1000),
            int(settings.training.eval_stride_size_seconds * 1000),
            model,
            zmuv_transform,
            ctx,
        )
    else:
        engine = SequenceInferenceEngine(model, zmuv_transform, ctx)
    return engine, ctx
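# Sketch of how a torch.hub entry point might wrap `_load_model`; the function
# name, model name, and workspace path below are illustrative assumptions, not
# necessarily the repository's actual hubconf values.
def hey_fire_fox(pretrained: bool = True, **kwargs) -> typing.Tuple[InferenceEngine, InferenceContext]:
    """Hypothetical wake-word entry point backed by a res8 workspace."""
    return _load_model(pretrained, "res8", "howl-models/howl/hey-fire-fox", "cpu", **kwargs)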
def evaluate_engine(dataset: WakeWordDataset,
                    prefix: str,
                    save: bool = False,
                    positive_set: bool = False,
                    write_errors: bool = True,
                    mixer: DatasetMixer = None):
    """Evaluate the current model on a dataset, accumulating a confusion matrix."""
    std_transform.eval()

    if use_frame:
        engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                      int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                      SETTINGS.audio.sample_rate,
                                      model,
                                      zmuv_transform,
                                      negative_label=ctx.negative_label,
                                      coloring=ctx.coloring)
    else:
        engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                         model,
                                         zmuv_transform,
                                         negative_label=ctx.negative_label,
                                         coloring=ctx.coloring)

    model.eval()
    conf_matrix = ConfusionMatrix()
    pbar = tqdm(dataset, desc=prefix)
    if write_errors:
        with (ws.path / 'errors.tsv').open('a') as f:
            print(prefix, file=f)

    for idx, ex in enumerate(pbar):
        if mixer is not None:
            ex, = mixer([ex])
        audio_data = ex.audio_data.to(device)
        engine.reset()
        seq_present = engine.infer(audio_data)
        # Log misclassified examples for later inspection
        if seq_present != positive_set and write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
        conf_matrix.increment(seq_present, positive_set)
        pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))

    logging.info(f'{conf_matrix}')
    if save and not args.eval:
        writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
        ws.increment_model(model, conf_matrix.tp)
    if args.eval:
        threshold = engine.threshold
        with (ws.path / (str(round(threshold, 2)) + '_results.csv')).open('a') as f:
            f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')
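# For reference: the `mcc` value shown in the progress bar is the Matthews
# correlation coefficient over the four confusion counts. A minimal standalone
# sketch of the standard formula; howl's ConfusionMatrix may handle edge cases
# differently:
import math

def matthews_corrcoef(tp: int, tn: int, fp: int, fn: int) -> float:
    """MCC = (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0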
def main():
    apb = ArgumentParserBuilder()
    apb.add_options(
        opt("--model", type=str, choices=RegisteredModel.registered_names(), default="las"),
        opt("--workspace", type=str, default=str(Path("workspaces") / "default")),
    )
    args = apb.parser.parse_args()

    ws = Workspace(Path(args.workspace), delete_existing=False)
    settings = ws.load_settings()

    use_frame = settings.training.objective == "frame"
    ctx = InferenceContext(settings.training.vocab,
                           token_type=settings.training.token_type,
                           use_blank=not use_frame)

    device = torch.device(settings.training.device)
    zmuv_transform = ZmuvTransform().to(device)
    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).eval()
    zmuv_transform.load_state_dict(torch.load(str(ws.path / "zmuv.pt.bin"), map_location=device))
    ws.load_model(model, best=True)
    model.streaming()

    if use_frame:
        engine = FrameInferenceEngine(
            int(settings.training.max_window_size_seconds * 1000),
            int(settings.training.eval_stride_size_seconds * 1000),
            model,
            zmuv_transform,
            ctx,
        )
    else:
        engine = SequenceInferenceEngine(model, zmuv_transform, ctx)

    client = HowlClient(engine, ctx)
    client.start().join()
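# Standard script entry point; the module invocation in the comment is an
# assumption about how this demo would be launched from the repository root.
if __name__ == "__main__":
    # e.g. python -m training.run.demo --model res8 --workspace workspaces/default
    main()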