def save_weights(
    self,
    filepath,
    overwrite=True,
    save_format=None,
    options=None,
):
    """Save the model weights, routing the target through ``file_util.save_file``.

    ``filepath`` is handed to the ``file_util.save_file`` context manager and
    the path that context yields is the one the parent class's
    ``save_weights`` writes to.

    Args:
        filepath: Destination given by the caller; passed to
            ``file_util.save_file``.
        overwrite: Forwarded unchanged to the parent ``save_weights``.
        save_format: Forwarded unchanged to the parent ``save_weights``.
        options: Forwarded unchanged to the parent ``save_weights``.
    """
    # Everything except the path is passed straight through to the parent.
    passthrough = {
        "overwrite": overwrite,
        "save_format": save_format,
        "options": options,
    }
    with file_util.save_file(filepath) as resolved_path:
        super().save_weights(filepath=resolved_path, **passthrough)
def save(
    self,
    filepath,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None,
    save_traces=True,
):
    """Save the whole model, routing the target through ``file_util.save_file``.

    ``filepath`` is handed to the ``file_util.save_file`` context manager and
    the path that context yields is the one the parent class's ``save``
    writes to.

    Args:
        filepath: Destination given by the caller; passed to
            ``file_util.save_file``.
        overwrite: Forwarded unchanged to the parent ``save``.
        include_optimizer: Forwarded unchanged to the parent ``save``.
        save_format: Forwarded unchanged to the parent ``save``.
        signatures: Forwarded unchanged to the parent ``save``.
        options: Forwarded unchanged to the parent ``save``.
        save_traces: Forwarded unchanged to the parent ``save``.
    """
    # Everything except the path is passed straight through to the parent.
    passthrough = {
        "overwrite": overwrite,
        "include_optimizer": include_optimizer,
        "save_format": save_format,
        "signatures": signatures,
        "options": options,
        "save_traces": save_traces,
    }
    with file_util.save_file(filepath) as resolved_path:
        super().save(filepath=resolved_path, **passthrough)
def run_testing(
    model: BaseModel,
    test_dataset: ASRSliceDataset,
    test_data_loader: tf.data.Dataset,
    output: str,
):
    """Predict over the test data, write the results as a TSV file, and evaluate it.

    The output path is preprocessed with ``file_util.preprocess_paths`` and
    wrapped in ``file_util.save_file``. If the resulting file already exists,
    the user is asked on stdin whether to overwrite it. When overwriting (or
    when no file exists yet), ``model.predict`` is run and one tab-separated
    row per prediction is written under the header
    ``PATH  DURATION  GROUNDTRUTH  GREEDY  BEAMSEARCH``. Finally
    ``app_util.evaluate_results`` is called on the file.

    Args:
        model: Model whose ``predict`` yields, per example, three byte strings
            (ground truth, greedy decode, beam-search decode).
        test_dataset: Dataset providing ``entries`` (indexed in prediction
            order, each unpacking to path, duration and one ignored field) and
            ``total_steps``.
        test_data_loader: Dataset passed to ``model.predict``.
        output: Path of the result file to write.
    """
    with file_util.save_file(file_util.preprocess_paths(output)) as filepath:
        overwrite = True
        if tf.io.gfile.exists(filepath):
            # Never clobber an existing result file without asking; only an
            # answer of "y" (any case) counts as consent.
            overwrite = input(f"Overwrite existing result file {filepath} ? (y/n): ").lower() == "y"
        if overwrite:
            results = model.predict(test_data_loader, verbose=1)
            logger.info(f"Saving result to {output} ...")
            # Transcripts are decoded from UTF-8 below and may contain
            # non-ASCII text, so write the file as UTF-8 explicitly instead of
            # relying on the platform's locale encoding.
            with open(filepath, "w", encoding="utf-8") as openfile:
                openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n")
                # The context manager closes the bar even if a row fails to
                # decode or write.
                # NOTE(review): the bar's total is ``total_steps`` with unit
                # "batch", but it advances once per prediction -- confirm
                # these agree.
                with tqdm(total=test_dataset.total_steps, unit="batch") as progbar:
                    for i, pred in enumerate(results):
                        groundtruth, greedy, beamsearch = [x.decode("utf-8") for x in pred]
                        path, duration, _ = test_dataset.entries[i]
                        openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n")
                        progbar.update(1)
        # Evaluation is not conditional on overwriting: when the user keeps
        # the existing file, that file is what gets scored.
        app_util.evaluate_results(filepath)
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.test_dataset_config)) # build model jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes) jasper.make(speech_featurizer.shape) jasper.load_weights(args.saved) jasper.summary(line_length=100) jasper.add_featurizers(speech_featurizer, text_featurizer) batch_size = args.bs or config.learning_config.running_config.batch_size test_data_loader = test_dataset.create(batch_size) with file_util.save_file(file_util.preprocess_paths(args.output)) as filepath: overwrite = True if tf.io.gfile.exists(filepath): overwrite = input( f"Overwrite existing result file {filepath} ? (y/n): ").lower( ) == "y" if overwrite: results = jasper.predict(test_data_loader, verbose=1) print(f"Saving result to {args.output} ...") with open(filepath, "w") as openfile: openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n") progbar = tqdm(total=test_dataset.total_steps, unit="batch") for i, pred in enumerate(results): groundtruth, greedy, beamsearch = [ x.decode('utf-8') for x in pred ]