def test_contains():
    """Check ``in`` against a one-entry table.

    Both the symbol "a" and its value 1 must be reported as members, the
    unused value 2 must not, and a ``None`` lookup must raise ``ValueError``.
    """
    table = SymbolsTable()
    table.add("a", 1)
    for member in ("a", 1):
        assert member in table
    assert 2 not in table
    with pytest.raises(ValueError, match="SymbolsTable contains pairs"):
        assert None in table  # noqa: expected type
def test_iterator(self):
    """Iterating yields ("a", 1) then ("b", 2), then raises StopIteration."""
    expected_pairs = [("a", 1), ("b", 2)]
    table = SymbolsTable()
    for symbol, value in expected_pairs:
        table.add(symbol, value)

    iterator = iter(table)
    for pair in expected_pairs:
        self.assertEqual(next(iterator), pair)
    # Nothing may be left once both pairs have been consumed.
    with self.assertRaises(StopIteration):
        next(iterator)
def test_iterator():
    """Iterating yields ("a", 1) then ("b", 2), then raises StopIteration."""
    expected_pairs = [("a", 1), ("b", 2)]
    table = SymbolsTable()
    for symbol, value in expected_pairs:
        table.add(symbol, value)

    iterator = iter(table)
    for pair in expected_pairs:
        assert next(iterator) == pair
    # Nothing may be left once both pairs have been consumed.
    with pytest.raises(StopIteration):
        next(iterator)
def test_getitem():
    """Indexing maps symbol -> value and value -> symbol; misses give None."""
    table = SymbolsTable()
    table.add("a", 1)
    table.add("b", 2)

    # Lookups work in both directions.
    for key, expected in (("a", 1), ("b", 2), (1, "a"), (2, "b")):
        assert table[key] == expected
    # Unknown values and unknown symbols both yield None instead of raising.
    for missing in (-9, "c"):
        assert table[missing] is None
def test_getitem(self):
    """Indexing maps symbol -> value and value -> symbol; misses give None."""
    table = SymbolsTable()
    table.add("a", 1)
    table.add("b", 2)

    # Lookups work in both directions.
    for key, expected in (("a", 1), ("b", 2), (1, "a"), (2, "b")):
        self.assertEqual(table[key], expected)
    # Unknown values and unknown symbols both yield None instead of raising.
    for missing in (-9, "c"):
        self.assertEqual(table[missing], None)
def kws_assessment_column_index(
    syms,  # type: SymbolsTable
    delimiters,  # type: Iterable[int]
    kws_ref,  # type: AnyStr
    lattice_ark,  # type: AnyStr
    acoustic_scale,  # type: float
    nbest,  # type: int
    queries=None,  # type: Optional[AnyStr]
    verbose=False,  # type: bool
    max_states=None,  # type: Optional[int]
    max_arcs=None,  # type: Optional[int]
    char_sep=u"",  # type: str
    compute_curve=False,  # type: bool
):
    """Run a KWS assessment over a per-frame (column) posteriorgram index.

    The output of ``make_posteriorgram_process`` is read line by line. The
    first space-separated token of a line is the utterance id, followed by
    bracketed frames of the form ``[ id score id score ... ]``. For every
    utterance the highest score of each word id over all frames is kept, and
    only the ``nbest`` best-scoring words are retained.

    Each retained word id is looked up in the word symbols table produced by
    the process; that entry is split on ``"_"``, every piece is mapped
    through ``syms`` and the results are joined with ``char_sep``. One line
    ``"utt word rel score"`` is written per hypothesis to a temporary file,
    where ``rel`` is 1 if ``(word, utt)`` is in the reference set and 0
    otherwise. The file is then completed by ``add_missing_words`` and handed
    to ``make_kws_assessment_process``.

    Args:
        syms: table mapping integer ids to the symbols used to spell a word.
        delimiters: forwarded to ``make_posteriorgram_process``.
        kws_ref: forwarded to ``get_kws_ref_set`` to build the reference set.
        lattice_ark: forwarded to ``make_posteriorgram_process``.
        acoustic_scale: forwarded to ``make_posteriorgram_process``.
        nbest: number of best-scoring words kept per utterance.
        queries: forwarded to ``get_include_words_set`` and to the
            assessment process; when the resulting set is ``None`` no
            hypothesis is filtered out.
        verbose: forwarded to both helper processes.
        max_states: forwarded to ``make_posteriorgram_process``.
        max_arcs: forwarded to ``make_posteriorgram_process``.
        char_sep: separator placed between the symbols of a word.
        compute_curve: if true, return the raw assessment output.

    Returns:
        The decoded output of the assessment process when ``compute_curve``
        is true, otherwise ``kws_assessment_parse_output`` applied to it.

    Raises:
        ValueError: if a non-empty posteriorgram line has no utterance id
            followed by content.
    """
    p1, word_syms_str = make_posteriorgram_process(
        delimiters, lattice_ark, acoustic_scale, verbose, max_states, max_arcs
    )
    tmppath = None
    try:
        # Patterns are loop invariants; compile them once.
        utt_re = re.compile(r"^([^ ]+) (.)+$")
        frame_re = re.compile(r"\[ ([0-9]+ [0-9.e-]+ )+\]")
        pair_re = re.compile(r"([0-9]+) ([0-9.e-]+) ")

        index = {}
        try:
            for raw_line in p1.stdout:
                line = raw_line.decode("utf-8").strip()
                if not line:
                    # A blank line carries no utterance; nothing to index.
                    continue
                utt_match = utt_re.match(line)
                if utt_match is None:
                    raise ValueError(
                        "Unexpected posteriorgram line: {!r}".format(line)
                    )
                utt = utt_match.group(1)
                # Best score seen for each word id across all frames.
                best_score = {}
                for frame_match in frame_re.finditer(line):
                    for pair in pair_re.finditer(frame_match.group(0)):
                        word_id = int(pair.group(1))
                        score = float(pair.group(2))
                        if (word_id not in best_score
                                or best_score[word_id] < score):
                            best_score[word_id] = score
                ranked = [(score, word_id)
                          for word_id, score in best_score.items()]
                ranked.sort(reverse=True)
                index[utt] = ranked[:nbest]
        finally:
            p1.stdout.close()

        word_syms = SymbolsTable(word_syms_str)
        queries_set = get_include_words_set(queries)
        kws_ref_set = get_kws_ref_set(kws_ref)

        fd, tmppath = tempfile.mkstemp()
        tmpf = os.fdopen(fd, "w")
        kws_hyp_set = set()
        try:
            for utt in index:
                for score, word_id in index[utt]:
                    word = char_sep.join(
                        [syms[int(x)] for x in word_syms[word_id].split("_")]
                    )
                    if queries_set is None or word in queries_set:
                        rel = 1 if (word, utt) in kws_ref_set else 0
                        tmpf.write(
                            u"{} {} {} {}\n".format(utt, word, rel, score)
                        )
                        kws_hyp_set.add((word, utt))
            add_missing_words(kws_ref_set, kws_hyp_set, tmpf)
        finally:
            # Make sure everything is flushed before the scorer reads the
            # file. add_missing_words may already have closed it (not visible
            # from here); closing an already closed file is a no-op.
            tmpf.close()

        p2 = make_kws_assessment_process(
            tmppath, queries, verbose, compute_curve
        )
        out = p2.communicate()[0].decode("utf-8")
    finally:
        # Remove the temporary files even when a step above failed.
        for path in (tmppath, word_syms_str):
            if path is not None and os.path.exists(path):
                os.remove(path)

    if compute_curve:
        return out
    return kws_assessment_parse_output(out)
def test_load(self):
    """Loading from a file-like object skips blank lines and reads 2 pairs."""
    source = StringIO(u"\n\na 1\nb 2\n")
    table = SymbolsTable(source)

    self.assertEqual(len(table), 2)
    # Both directions of the mapping must be populated.
    for key, expected in (("a", 1), ("b", 2), (1, "a"), (2, "b")):
        self.assertEqual(table[key], expected)
def test_load(tmpdir, as_type):
    """Loading from a file on disk skips blank lines and reads 2 pairs.

    ``as_type`` is applied to the path before it is given to
    ``SymbolsTable`` (it is supplied from outside this function).
    """
    path = tmpdir / "f"
    path.write_text("\n\na 1\nb 2\n", "utf-8")
    table = SymbolsTable(as_type(path))

    assert len(table) == 2
    # Both directions of the mapping must be populated.
    for key, expected in (("a", 1), ("b", 2), (1, "a"), (2, "b")):
        assert table[key] == expected
def kws_assessment_utterance_index(
    syms,  # type: SymbolsTable
    delimiters,  # type: Iterable[int]
    kws_ref,  # type: AnyStr
    lattice_ark,  # type: AnyStr
    acoustic_scale,  # type: float
    nbest,  # type: int
    queries=None,  # type: Optional[AnyStr]
    verbose=False,  # type: bool
    max_states=None,  # type: Optional[int]
    max_arcs=None,  # type: Optional[int]
    char_sep=u"",  # type: str
    compute_curve=False,  # type: bool
):
    """Run a KWS assessment over an utterance-level index.

    The output of ``make_index_utterance_process`` is read line by line. The
    first whitespace-separated field is the utterance id; the remaining
    fields are consumed in groups of three, of which the first is a word id
    and the second its score (the third is not used here).

    Each word id is looked up in the word symbols table produced by the
    process; that entry is split on ``"_"``, every piece is mapped through
    ``syms`` and the results are joined with ``char_sep``. One line
    ``"utt word rel score"`` is written per hypothesis to a temporary file,
    where ``rel`` is 1 if ``(word, utt)`` is in the reference set and 0
    otherwise. The file is then completed by ``add_missing_words`` and handed
    to ``make_kws_assessment_process``.

    Args:
        syms: table mapping integer ids to the symbols used to spell a word.
        delimiters: forwarded to ``make_index_utterance_process``.
        kws_ref: forwarded to ``get_kws_ref_set`` to build the reference set.
        lattice_ark: forwarded to ``make_index_utterance_process``.
        acoustic_scale: forwarded to ``make_index_utterance_process``.
        nbest: forwarded to ``make_index_utterance_process``.
        queries: forwarded to both helper processes.
        verbose: forwarded to both helper processes.
        max_states: forwarded to ``make_index_utterance_process``.
        max_arcs: forwarded to ``make_index_utterance_process``.
        char_sep: separator placed between the symbols of a word; also
            forwarded to ``make_index_utterance_process``.
        compute_curve: if true, return the raw assessment output.

    Returns:
        The decoded output of the assessment process when ``compute_curve``
        is true, otherwise ``kws_assessment_parse_output`` applied to it.
    """
    p1, word_syms_str = make_index_utterance_process(
        delimiters,
        lattice_ark,
        acoustic_scale,
        syms,
        verbose,
        max_states,
        max_arcs,
        queries,
        char_sep,
        nbest,
    )
    tmppath = None
    try:
        word_syms = SymbolsTable(word_syms_str)
        kws_ref_set = get_kws_ref_set(kws_ref)

        fd, tmppath = tempfile.mkstemp()
        tmpf = os.fdopen(fd, "w")
        kws_hyp_set = set()
        try:
            try:
                for raw_line in p1.stdout:
                    fields = raw_line.decode("utf-8").split()
                    if not fields:
                        # A blank line carries no utterance; skip it.
                        continue
                    utt = fields[0]
                    # Fields after the id come in triples: word id, score
                    # and a third value that is ignored.
                    for i in range(1, len(fields), 3):
                        word = word_syms[int(fields[i])]
                        word = char_sep.join(
                            [syms[int(x)] for x in word.split("_")]
                        )
                        score = fields[i + 1]
                        rel = 1 if (word, utt) in kws_ref_set else 0
                        tmpf.write(
                            u"{} {} {} {}\n".format(utt, word, rel, score)
                        )
                        kws_hyp_set.add((word, utt))
            finally:
                p1.stdout.close()
            add_missing_words(kws_ref_set, kws_hyp_set, tmpf)
        finally:
            # Make sure everything is flushed before the scorer reads the
            # file. add_missing_words may already have closed it (not visible
            # from here); closing an already closed file is a no-op.
            tmpf.close()

        p2 = make_kws_assessment_process(
            tmppath, queries, verbose, compute_curve
        )
        out = p2.communicate()[0].decode("utf-8")
    finally:
        # Remove the temporary files even when a step above failed.
        for path in (tmppath, word_syms_str):
            if path is not None and os.path.exists(path):
                os.remove(path)

    if compute_curve:
        return out
    return kws_assessment_parse_output(out)
def test_save(tmpdir):
    """Saving a two-entry table writes one "symbol value" line per entry."""
    table = SymbolsTable()
    for symbol, value in (("a", 1), ("b", 2)):
        table.add(symbol, value)

    target = tmpdir / "syms"
    table.save(str(target))

    written = target.read_text("utf-8")
    assert written == "a 1\nb 2\n"
def test_add_repeated_value(self):
    """Adding a new symbol with an already used value raises KeyError.

    The ``msg`` argument of ``assertRaises`` is only the text shown when the
    assertion fails; it does not check the exception text. The expected
    message is therefore verified explicitly on the caught exception.
    """
    expected = (
        'Value "0" was already present in the table (assigned to symbol "a")'
    )
    st = SymbolsTable()
    st.add("a", 0)
    with self.assertRaises(KeyError) as ctx:
        st.add("b", 0)
    # str() of a KeyError is the repr of its argument, so test containment
    # rather than equality.
    self.assertIn(expected, str(ctx.exception))
def test_save(self):
    """Saving a two-entry table writes one "symbol value" line per entry.

    The temporary file is created with ``delete=False`` so it can be
    reopened by name; it is closed and removed in ``finally`` so a failing
    assertion does not leave it behind.
    """
    st = SymbolsTable()
    st.add("a", 1)
    st.add("b", 2)
    table_file = NamedTemporaryFile(delete=False)
    try:
        st.save(table_file)
        with open(table_file.name, "r") as f:
            table_content = f.read()
        self.assertEqual(table_content, "a 1\nb 2\n")
    finally:
        # close() is a no-op if the handle has already been closed.
        table_file.close()
        os.remove(table_file.name)
def test_add_repeated_value():
    """Adding a new symbol with an already used value raises KeyError."""
    expected_message = (
        'Value "0" was already present in the table (assigned to symbol "a")'
    )
    table = SymbolsTable()
    table.add("a", 0)
    # The message contains regex metacharacters, hence the escape.
    with pytest.raises(KeyError, match=re.escape(expected_message)):
        table.add("b", 0)
def test_load_value_error(self):
    """Loading a table whose second column is not an integer must fail."""
    # The "b c" line has a non-numeric value.
    malformed = StringIO(u"\n\na 1\nb c\n")
    self.assertRaises(ValueError, SymbolsTable, malformed)
def test_add():
    """Each newly added (symbol, value) pair grows the table by one."""
    table = SymbolsTable()
    entries = [("<eps>", 0), ("a", 1)]
    for count, (symbol, value) in enumerate(entries, 1):
        table.add(symbol, value)
        assert len(table) == count
def test_add_valid_repeated():
    """Re-adding an identical (symbol, value) pair leaves one entry."""
    table = SymbolsTable()
    for _ in range(2):
        table.add("<eps>", 0)
    assert len(table) == 1
def test_add_valid_repeated(self):
    """Re-adding an identical (symbol, value) pair leaves one entry."""
    table = SymbolsTable()
    for _ in range(2):
        table.add("<eps>", 0)
    self.assertEqual(len(table), 1)
def test_add(self):
    """Each newly added (symbol, value) pair grows the table by one."""
    table = SymbolsTable()
    entries = [("<eps>", 0), ("a", 1)]
    for count, (symbol, value) in enumerate(entries, 1):
        table.add(symbol, value)
        self.assertEqual(len(table), count)
def test_empty(self):
    """A freshly constructed table has length zero."""
    self.assertEqual(len(SymbolsTable()), 0)
default=0.5, help="Plot objects with this minimum relevance", ) parser.add_argument( "index_type", choices=("position", "segment"), help="Type of the KWS index to process", ) parser.add_argument("imgs_dir", type=str, help="Directory containing the indexed images") parser.add_argument("index", type=argparse.FileType("r"), help="File containing the KWS index") args = parser.parse_args() syms = SymbolsTable(args.symbols_table) if args.symbols_table else None batch_images = [] for sample in args.index: m = re.match(r"^([^ ]+) +(.+)$", sample) sample_id = m.group(1) print(sample_id) matches = m.group(2).split(";") # Parse index entries if args.index_type == "position": matches = [parse_position_match(m, syms) for m in matches] matches = sorted(matches, key=lambda x: x.position) elif args.index_type == "segment": matches = [parse_segment_match(m, syms) for m in matches] matches = sorted(matches, key=lambda x: x.beg)
def test_load_value_error(tmpdir):
    """Loading a table whose second column is not an integer must fail."""
    path = tmpdir / "f"
    # The "b c" line has a non-numeric value.
    path.write_text("\n\na 1\nb c\n", "utf-8")
    with pytest.raises(ValueError):
        SymbolsTable(path)
def test_empty():
    """A freshly constructed table has length zero."""
    assert len(SymbolsTable()) == 0
default=16, help='Average adaptive pooling of the images before the ' 'LSTM layers') add_argument('--lstm_hidden_size', type=int, default=128) add_argument('--lstm_num_layers', type=int, default=1) add_argument('--add_softmax', action='store_true') add_argument('--add_boundary_blank', action='store_true') add_argument('syms', help='Symbols table mapping from strings to integers') add_argument('img_dir', help='Directory containing word images') add_argument('gt_file', help='') add_argument('checkpoint', help='') add_argument('output', type=argparse.FileType('w')) args = args() # Build neural network syms = SymbolsTable(args.syms) model = build_ctc_model(num_outputs=len(syms), adaptive_pool_height=args.adaptive_pool_height, lstm_hidden_size=args.lstm_hidden_size, lstm_num_layers=args.lstm_num_layers) # Load checkpoint ckpt = torch.load(args.checkpoint) if 'model' in ckpt and 'optimizer' in ckpt: model.load_state_dict(ckpt['model']) else: model.load_state_dict(ckpt) # Ensure parameters are in the correct device model.eval() if args.gpu > 0:
params_to_optimize.append("acoustic_scale") space.append( hp.quniform( "acoustic_scale", args.acoustic_scale_min, args.acoustic_scale_max, args.acoustic_scale_quant, )) acoustic_scale_global = None acoustic_scale_key = len(space) - 1 else: acoustic_scale_global = args.acoustic_scale_max acoustic_scale_key = None syms = [ SymbolsTable(args.syms_pattern.format(cv=cv)) for cv in range(args.num_partitions) ] @lru_cache(maxsize=None) def objective(params): if prior_scale_key is not None: prior_scale = params[prior_scale_key] else: prior_scale = prior_scale_global if acoustic_scale_key is not None: acoustic_scale = params[acoustic_scale_key] else: acoustic_scale = acoustic_scale_global