示例#1
0
def test_contains():
    """Check ``in`` for symbols, values, absent values and unsupported keys."""
    table = SymbolsTable()
    table.add("a", 1)
    # Both the symbol and its numeric value count as members.
    for member in ("a", 1):
        assert member in table
    assert 2 not in table
    # A key that is neither a str symbol nor an int value is rejected.
    with pytest.raises(ValueError, match="SymbolsTable contains pairs"):
        assert None in table  # noqa: expected type
示例#2
0
 def test_iterator(self):
     """Iteration yields (symbol, value) pairs in the order they were added."""
     expected = [("a", 1), ("b", 2)]
     table = SymbolsTable()
     for symbol, value in expected:
         table.add(symbol, value)
     iterator = iter(table)
     for pair in expected:
         self.assertEqual(next(iterator), pair)
     # Once every pair has been produced the iterator is exhausted.
     with self.assertRaises(StopIteration):
         next(iterator)
示例#3
0
def test_iterator():
    """Iteration yields (symbol, value) pairs in the order they were added."""
    pairs = [("a", 1), ("b", 2)]
    table = SymbolsTable()
    for pair in pairs:
        table.add(*pair)
    iterator = iter(table)
    for pair in pairs:
        assert next(iterator) == pair
    # Once every pair has been produced the iterator is exhausted.
    with pytest.raises(StopIteration):
        next(iterator)
示例#4
0
def test_getitem():
    """Look up values by symbol and symbols by value; unknown keys give None."""
    table = SymbolsTable()
    table.add("a", 1)
    table.add("b", 2)
    # Forward lookups: symbol -> value.
    assert table["a"] == 1
    assert table["b"] == 2
    # Reverse lookups: value -> symbol.
    assert table[1] == "a"
    assert table[2] == "b"
    # Keys of either kind that were never added.
    for missing in (-9, "c"):
        assert table[missing] is None
示例#5
0
 def test_getitem(self):
     """Look up values by symbol and symbols by value; unknown keys give None."""
     table = SymbolsTable()
     table.add("a", 1)
     table.add("b", 2)
     # (key, expected result) pairs, checked in order: forward lookups,
     # reverse lookups, then keys that were never added.
     lookups = (
         ("a", 1),
         ("b", 2),
         (1, "a"),
         (2, "b"),
         (-9, None),
         ("c", None),
     )
     for key, expected in lookups:
         self.assertEqual(table[key], expected)
def kws_assessment_column_index(
        syms,  # type: SymbolsTable
        delimiters,  # type: Iterable[int]
        kws_ref,  # type: AnyStr
        lattice_ark,  # type: AnyStr
        acoustic_scale,  # type: float
        nbest,  # type: int
        queries=None,  # type: Optional[AnyStr]
        verbose=False,  # type: bool
        max_states=None,  # type: Optional[int]
        max_arcs=None,  # type: Optional[int]
        char_sep=u"",  # type: str
        compute_curve=False,  # type: bool
):
    """Assess KWS performance using a per-utterance "column" index.

    Reads the output of ``make_posteriorgram_process`` line by line. For
    every utterance it keeps the highest score seen for each word id across
    all bracketed frames, then keeps the ``nbest`` best-scoring words as that
    utterance's hypotheses. The hypotheses are written, with their relevance
    according to ``kws_ref``, to a temporary file that is passed to
    ``make_kws_assessment_process``.

    Args:
        syms: table mapping the integer pieces of each word symbol
            (``"_"``-separated) to strings, which are joined with ``char_sep``.
        delimiters: forwarded to ``make_posteriorgram_process``.
        kws_ref: forwarded to ``get_kws_ref_set``, which builds the set of
            reference ``(word, utterance)`` pairs used for relevance.
        lattice_ark: forwarded to ``make_posteriorgram_process``.
        acoustic_scale: forwarded to ``make_posteriorgram_process``.
        nbest: number of best-scoring words kept per utterance.
        queries: forwarded to ``get_include_words_set``. Only words in that
            set are written when it is not None. Also forwarded to the
            assessment process.
        verbose: forwarded to both subprocess helpers.
        max_states: forwarded to ``make_posteriorgram_process``.
        max_arcs: forwarded to ``make_posteriorgram_process``.
        char_sep: separator used to join the symbols that compose a word.
        compute_curve: if True, return the assessment output unparsed.

    Returns:
        The raw decoded assessment output if ``compute_curve`` is True,
        otherwise the result of ``kws_assessment_parse_output`` on it.
    """
    # Compiled once, since they are applied to every posteriorgram line.
    utt_re = re.compile(r"^([^ ]+) (.)+$")
    frame_re = re.compile(r"\[ ([0-9]+ [0-9.e-]+ )+\]")
    entry_re = re.compile(r"([0-9]+) ([0-9.e-]+) ")

    p1, word_syms_str = make_posteriorgram_process(delimiters, lattice_ark,
                                                   acoustic_scale, verbose,
                                                   max_states, max_arcs)
    tmppath = None
    try:
        index = {}
        for line in p1.stdout:
            line = line.decode("utf-8").strip()
            utt_match = utt_re.match(line)
            if utt_match is None:
                # Blank line (or a bare utterance id): nothing to index.
                continue
            utt = utt_match.group(1)
            best_score = {}
            for frame in frame_re.finditer(line):
                for entry in entry_re.finditer(frame.group(0)):
                    word = int(entry.group(1))
                    score = float(entry.group(2))
                    if word not in best_score or best_score[word] < score:
                        best_score[word] = score
            # Decreasing score; ties are broken by decreasing word id.
            ranked = sorted(
                ((score, word) for word, score in best_score.items()),
                reverse=True,
            )
            index[utt] = ranked[:nbest]

        word_syms = SymbolsTable(word_syms_str)
        queries_set = get_include_words_set(queries)
        kws_ref_set = get_kws_ref_set(kws_ref)
        fd, tmppath = tempfile.mkstemp()
        kws_hyp_set = set()
        # The file must be closed (and therefore flushed) before the
        # assessment process reads it back from ``tmppath``.
        with os.fdopen(fd, "w") as tmpf:
            for utt, hyps in index.items():
                for score, word in hyps:
                    word = char_sep.join(
                        [syms[int(x)] for x in word_syms[word].split("_")])
                    if queries_set is None or word in queries_set:
                        rel = 1 if (word, utt) in kws_ref_set else 0
                        tmpf.write(
                            u"{} {} {} {}\n".format(utt, word, rel, score))
                        kws_hyp_set.add((word, utt))
            add_missing_words(kws_ref_set, kws_hyp_set, tmpf)

        p2 = make_kws_assessment_process(tmppath, queries, verbose,
                                         compute_curve)
        out = p2.communicate()[0].decode("utf-8")
    finally:
        p1.stdout.close()
        # Remove temporary files even when something above failed.
        for path in (tmppath, word_syms_str):
            if path is not None and os.path.exists(path):
                os.remove(path)

    if compute_curve:
        return out
    return kws_assessment_parse_output(out)
示例#7
0
 def test_load(self):
     """Loading ignores blank lines and accepts runs of spaces as separators."""
     table = SymbolsTable(StringIO(u"\n\na   1\nb     2\n"))
     self.assertEqual(len(table), 2)
     pairs = (("a", 1), ("b", 2))
     # Forward lookups first, then reverse lookups.
     for symbol, value in pairs:
         self.assertEqual(table[symbol], value)
     for symbol, value in pairs:
         self.assertEqual(table[value], symbol)
示例#8
0
def test_load(tmpdir, as_type):
    """Load a table from a file whose path is passed through ``as_type``."""
    path = tmpdir / "f"
    path.write_text("\n\na   1\nb     2\n", "utf-8")
    table = SymbolsTable(as_type(path))
    assert len(table) == 2
    pairs = [("a", 1), ("b", 2)]
    # Forward lookups first, then reverse lookups.
    for symbol, value in pairs:
        assert table[symbol] == value
    for symbol, value in pairs:
        assert table[value] == symbol
def kws_assessment_utterance_index(
        syms,  # type: SymbolsTable
        delimiters,  # type: Iterable[int]
        kws_ref,  # type: AnyStr
        lattice_ark,  # type: AnyStr
        acoustic_scale,  # type: float
        nbest,  # type: int
        queries=None,  # type: Optional[AnyStr]
        verbose=False,  # type: bool
        max_states=None,  # type: Optional[int]
        max_arcs=None,  # type: Optional[int]
        char_sep=u"",  # type: str
        compute_curve=False,  # type: bool
):
    """Assess KWS performance using an utterance-level index.

    Reads the output of ``make_index_utterance_process``. Each line starts
    with an utterance id followed by groups of three fields, of which the
    first two (word id, score) are used here. Every entry is written, with
    its relevance according to ``kws_ref``, to a temporary file that is
    passed to ``make_kws_assessment_process``.

    Args:
        syms: table mapping the integer pieces of each word symbol
            (``"_"``-separated) to strings, which are joined with ``char_sep``.
            Also forwarded to ``make_index_utterance_process``.
        delimiters: forwarded to ``make_index_utterance_process``.
        kws_ref: forwarded to ``get_kws_ref_set``, which builds the set of
            reference ``(word, utterance)`` pairs used for relevance.
        lattice_ark: forwarded to ``make_index_utterance_process``.
        acoustic_scale: forwarded to ``make_index_utterance_process``.
        nbest: forwarded to ``make_index_utterance_process``.
        queries: forwarded to ``make_index_utterance_process`` and to the
            assessment process. No query filtering is done in this function
            itself; presumably the index process handles it (confirm).
        verbose: forwarded to both subprocess helpers.
        max_states: forwarded to ``make_index_utterance_process``.
        max_arcs: forwarded to ``make_index_utterance_process``.
        char_sep: joins the symbols of each word. Also forwarded to
            ``make_index_utterance_process``.
        compute_curve: if True, return the assessment output unparsed.

    Returns:
        The raw decoded assessment output if ``compute_curve`` is True,
        otherwise the result of ``kws_assessment_parse_output`` on it.
    """
    p1, word_syms_str = make_index_utterance_process(
        delimiters,
        lattice_ark,
        acoustic_scale,
        syms,
        verbose,
        max_states,
        max_arcs,
        queries,
        char_sep,
        nbest,
    )

    tmppath = None
    try:
        word_syms = SymbolsTable(word_syms_str)
        kws_ref_set = get_kws_ref_set(kws_ref)
        fd, tmppath = tempfile.mkstemp()
        kws_hyp_set = set()
        # The file must be closed (and therefore flushed) before the
        # assessment process reads it back from ``tmppath``.
        with os.fdopen(fd, "w") as tmpf:
            for line in p1.stdout:
                fields = line.decode("utf-8").split()
                if not fields:
                    # Blank line: nothing to index.
                    continue
                utt = fields[0]
                # Entries come in groups of three fields after the utt id.
                for i in range(1, len(fields), 3):
                    word = word_syms[int(fields[i])]
                    word = char_sep.join(
                        [syms[int(x)] for x in word.split("_")])
                    score = fields[i + 1]
                    rel = 1 if (word, utt) in kws_ref_set else 0
                    tmpf.write(u"{} {} {} {}\n".format(utt, word, rel, score))
                    kws_hyp_set.add((word, utt))
            p1.stdout.close()
            add_missing_words(kws_ref_set, kws_hyp_set, tmpf)

        p2 = make_kws_assessment_process(tmppath, queries, verbose,
                                         compute_curve)
        out = p2.communicate()[0].decode("utf-8")
    finally:
        # Closing an already-closed pipe is a no-op, so this is safe on the
        # success path and releases the pipe on failure.
        p1.stdout.close()
        # Remove temporary files even when something above failed.
        for path in (tmppath, word_syms_str):
            if path is not None and os.path.exists(path):
                os.remove(path)

    if compute_curve:
        return out
    return kws_assessment_parse_output(out)
示例#10
0
def test_save(tmpdir):
    """Saving writes each (symbol, value) pair on its own line."""
    table = SymbolsTable()
    for symbol, value in (("a", 1), ("b", 2)):
        table.add(symbol, value)
    out_path = tmpdir / "syms"
    table.save(str(out_path))
    saved_text = out_path.read_text("utf-8")
    assert saved_text == "a 1\nb 2\n"
示例#11
0
 def test_add_repeated_value(self):
     """Re-using a value for a different symbol raises a descriptive KeyError."""
     st = SymbolsTable()
     st.add("a", 0)
     # ``msg=`` on assertRaises only sets the failure text shown when nothing
     # is raised; it never inspects the exception. Capture the exception and
     # check its message explicitly instead.
     with self.assertRaises(KeyError) as ctx:
         st.add("b", 0)
     self.assertIn(
         'Value "0" was already present in the table (assigned to symbol "a")',
         str(ctx.exception),
     )
示例#12
0
 def test_save(self):
     """Saving to a file object writes each (symbol, value) pair on a line."""
     st = SymbolsTable()
     st.add("a", 1)
     st.add("b", 2)
     table_file = NamedTemporaryFile(delete=False)
     # Cleanups run LIFO: the handle is closed first, then the file is
     # deleted. This also happens when the assertion below fails, so neither
     # the handle nor the temporary file leaks.
     self.addCleanup(os.remove, table_file.name)
     self.addCleanup(table_file.close)
     st.save(table_file)
     with open(table_file.name, "r") as f:
         table_content = f.read()
     self.assertEqual(table_content, "a 1\nb 2\n")
示例#13
0
def test_add_repeated_value():
    """Re-using a value for a different symbol raises a descriptive KeyError."""
    table = SymbolsTable()
    table.add("a", 0)
    expected_message = re.escape(
        'Value "0" was already present in the table (assigned to symbol "a")'
    )
    with pytest.raises(KeyError, match=expected_message):
        table.add("b", 0)
示例#14
0
 def test_load_value_error(self):
     """A table line whose value is the non-numeric ``c`` raises ValueError."""
     bad_table = StringIO(u"\n\na   1\nb     c\n")
     self.assertRaises(ValueError, SymbolsTable, bad_table)
示例#15
0
def test_add():
    """Each newly added symbol grows the table by exactly one entry."""
    table = SymbolsTable()
    entries = (("<eps>", 0), ("a", 1))
    for expected_len, (symbol, value) in enumerate(entries, start=1):
        table.add(symbol, value)
        assert len(table) == expected_len
示例#16
0
def test_add_valid_repeated():
    """Adding the identical (symbol, value) pair twice stores it only once."""
    table = SymbolsTable()
    for _ in range(2):
        table.add("<eps>", 0)
    assert len(table) == 1
示例#17
0
 def test_add_valid_repeated(self):
     """Adding the identical (symbol, value) pair twice stores it only once."""
     table = SymbolsTable()
     for _ in range(2):
         table.add("<eps>", 0)
     self.assertEqual(len(table), 1)
示例#18
0
 def test_add(self):
     """Each newly added symbol grows the table by exactly one entry."""
     table = SymbolsTable()
     entries = (("<eps>", 0), ("a", 1))
     for expected_len, (symbol, value) in enumerate(entries, start=1):
         table.add(symbol, value)
         self.assertEqual(len(table), expected_len)
示例#19
0
 def test_empty(self):
     """A freshly constructed table has no entries."""
     self.assertEqual(len(SymbolsTable()), 0)
示例#20
0
        default=0.5,
        help="Plot objects with this minimum relevance",
    )
    parser.add_argument(
        "index_type",
        choices=("position", "segment"),
        help="Type of the KWS index to process",
    )
    parser.add_argument("imgs_dir",
                        type=str,
                        help="Directory containing the indexed images")
    parser.add_argument("index",
                        type=argparse.FileType("r"),
                        help="File containing the KWS index")
    args = parser.parse_args()
    syms = SymbolsTable(args.symbols_table) if args.symbols_table else None

    batch_images = []

    for sample in args.index:
        m = re.match(r"^([^ ]+) +(.+)$", sample)
        sample_id = m.group(1)
        print(sample_id)
        matches = m.group(2).split(";")
        # Parse index entries
        if args.index_type == "position":
            matches = [parse_position_match(m, syms) for m in matches]
            matches = sorted(matches, key=lambda x: x.position)
        elif args.index_type == "segment":
            matches = [parse_segment_match(m, syms) for m in matches]
            matches = sorted(matches, key=lambda x: x.beg)
示例#21
0
def test_load_value_error(tmpdir):
    """A table file whose value is the non-numeric ``c`` raises ValueError."""
    bad_file = tmpdir / "f"
    bad_file.write_text("\n\na   1\nb     c\n", "utf-8")
    with pytest.raises(ValueError):
        SymbolsTable(bad_file)
示例#22
0
def test_empty():
    """A freshly constructed table has no entries."""
    assert len(SymbolsTable()) == 0
示例#23
0
                 default=16,
                 help='Average adaptive pooling of the images before the '
                 'LSTM layers')
    add_argument('--lstm_hidden_size', type=int, default=128)
    add_argument('--lstm_num_layers', type=int, default=1)
    add_argument('--add_softmax', action='store_true')
    add_argument('--add_boundary_blank', action='store_true')
    add_argument('syms', help='Symbols table mapping from strings to integers')
    add_argument('img_dir', help='Directory containing word images')
    add_argument('gt_file', help='')
    add_argument('checkpoint', help='')
    add_argument('output', type=argparse.FileType('w'))
    args = args()

    # Build neural network
    syms = SymbolsTable(args.syms)
    model = build_ctc_model(num_outputs=len(syms),
                            adaptive_pool_height=args.adaptive_pool_height,
                            lstm_hidden_size=args.lstm_hidden_size,
                            lstm_num_layers=args.lstm_num_layers)

    # Load checkpoint
    ckpt = torch.load(args.checkpoint)
    if 'model' in ckpt and 'optimizer' in ckpt:
        model.load_state_dict(ckpt['model'])
    else:
        model.load_state_dict(ckpt)

    # Ensure parameters are in the correct device
    model.eval()
    if args.gpu > 0:
示例#24
0
        params_to_optimize.append("acoustic_scale")
        space.append(
            hp.quniform(
                "acoustic_scale",
                args.acoustic_scale_min,
                args.acoustic_scale_max,
                args.acoustic_scale_quant,
            ))
        acoustic_scale_global = None
        acoustic_scale_key = len(space) - 1
    else:
        acoustic_scale_global = args.acoustic_scale_max
        acoustic_scale_key = None

    syms = [
        SymbolsTable(args.syms_pattern.format(cv=cv))
        for cv in range(args.num_partitions)
    ]

    @lru_cache(maxsize=None)
    def objective(params):
        if prior_scale_key is not None:
            prior_scale = params[prior_scale_key]
        else:
            prior_scale = prior_scale_global

        if acoustic_scale_key is not None:
            acoustic_scale = params[acoustic_scale_key]
        else:
            acoustic_scale = acoustic_scale_global