def test_global_mvn_initialization_and_stats_saving(global_mvn):
    """Check that a GlobalMVN survives a ``to_file`` / ``from_file`` round trip.

    The stats are written to a temporary file, read back into a second
    ``GlobalMVN``, and the two ``state_dict()`` mappings are required to have
    the same keys (in the same order) and element-wise equal tensors.

    Args:
        global_mvn: the ``GlobalMVN`` instance under test.
    """
    with tempfile.NamedTemporaryFile() as tf:
        global_mvn.to_file(tf.name)
        global_mvn2 = GlobalMVN.from_file(tf.name)

        original_state = global_mvn.state_dict()
        restored_state = global_mvn2.state_dict()

        # Compare the keys explicitly: zipping the two item sequences would
        # silently stop at the shorter one and never look at the names, so a
        # reloaded object with missing or renamed entries could still pass.
        assert list(original_state.keys()) == list(restored_state.keys())

        for key, expected in original_state.items():
            # The key is the assertion message so a failure names the entry.
            assert torch.equal(expected, restored_state[key]), key
def test_global_mvn_from_cuts():
    """Check that ``GlobalMVN.from_cuts`` returns a ``GlobalMVN``.

    Exercised twice on the same cut set: once with default arguments and
    once with ``max_cuts=1``.
    """
    cut_set = CutSet.from_json("test/fixtures/ljspeech/cuts.json")

    # Build both estimates up front, then type-check each of them.
    estimates = (
        GlobalMVN.from_cuts(cut_set),
        GlobalMVN.from_cuts(cut_set, max_cuts=1),
    )
    for stats in estimates:
        assert isinstance(stats, GlobalMVN)
def global_mvn():
    """Return a ``GlobalMVN`` computed from the LJSpeech fixture cuts.

    Loads the cut set from ``test/fixtures/ljspeech/cuts.json`` and passes it
    to ``GlobalMVN.from_cuts``.
    """
    # NOTE(review): a test in this file takes a `global_mvn` argument, which
    # suggests this is meant to be a pytest fixture, but no fixture decorator
    # is visible on this definition -- confirm.
    return GlobalMVN.from_cuts(
        CutSet.from_json("test/fixtures/ljspeech/cuts.json")
    )