예제 #1
0
def test_global_mvn_initialization_and_stats_saving(global_mvn):
    """Round-trip ``global_mvn`` through a file and check nothing changed.

    The stats are written with ``to_file`` and reloaded with
    ``GlobalMVN.from_file``. The test then asserts that both ``state_dict()``
    results have identical key sets and that every tensor is
    ``torch.equal`` to its reloaded counterpart.
    """
    import os

    # Use a temporary *directory* rather than NamedTemporaryFile. Reopening
    # a still-open NamedTemporaryFile by name is not supported on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "global_mvn_stats")
        global_mvn.to_file(path)
        global_mvn2 = GlobalMVN.from_file(path)

    original_state = global_mvn.state_dict()
    restored_state = global_mvn2.state_dict()
    # Check the key sets first. zip() over the two dicts would silently stop
    # at the shorter one, so missing or extra entries would go unnoticed.
    assert original_state.keys() == restored_state.keys()
    for key, tensor in original_state.items():
        assert torch.equal(tensor, restored_state[key]), key
예제 #2
0
def test_global_mvn_from_cuts():
    """Build GlobalMVN stats from the fixture cuts, with and without ``max_cuts=1``.

    Both results must be ``GlobalMVN`` instances.
    """
    cut_set = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    all_cuts_stats = GlobalMVN.from_cuts(cut_set)
    one_cut_stats = GlobalMVN.from_cuts(cut_set, max_cuts=1)
    for stats in (all_cuts_stats, one_cut_stats):
        assert isinstance(stats, GlobalMVN)
예제 #3
0
def global_mvn():
    """Return GlobalMVN stats computed from the LJSpeech fixture cuts file.

    NOTE(review): this is requested by name as the ``global_mvn`` argument of
    a test in this file, so it is presumably registered as a pytest fixture.
    No ``@pytest.fixture`` decorator is visible here; confirm.
    """
    fixture_cuts = CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    stats = GlobalMVN.from_cuts(fixture_cuts)
    return stats