def test_validate_spect_data_set_cuda(temp_dir):
    """Validation rejects CUDA tensors on disk; with fix=True it warns and rewrites them on CPU."""
    torch.manual_seed(29)
    subdirs = {name: os.path.join(temp_dir, name) for name in ("feat", "ali", "ref")}
    for subdir in subdirs.values():
        os.makedirs(subdir)
    feats_path = os.path.join(subdirs["feat"], "a.pt")
    ali_path = os.path.join(subdirs["ali"], "a.pt")
    ref_path = os.path.join(subdirs["ref"], "a.pt")
    # all-CPU baseline must validate cleanly
    torch.save(torch.rand(10, 5), feats_path)
    torch.save(torch.randint(10, (10,), dtype=torch.long), ali_path)
    torch.save(torch.tensor([1, 2, 3]), ref_path)
    data_set = data.SpectDataSet(temp_dir)
    data.validate_spect_data_set(data_set)
    # only the features live on the GPU
    torch.save(torch.rand(10, 5).cuda(), feats_path)
    with pytest.raises(ValueError, match="cuda"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # to CPU
    data.validate_spect_data_set(data_set)
    # now every tensor lives on the GPU
    torch.save(torch.rand(10, 5).cuda(), feats_path)
    torch.save(torch.randint(10, (10,), dtype=torch.long).cuda(), ali_path)
    torch.save(torch.tensor([1, 2, 3]).cuda(), ref_path)
    with pytest.raises(ValueError, match="cuda"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # to CPU
    data.validate_spect_data_set(data_set)
def test_spect_data_set_warnings(temp_dir):
    """Utterances missing feats or ali are dropped; the default constructor warns once per missing file."""
    torch.manual_seed(1)
    feat_dir = os.path.join(temp_dir, "feat")
    ali_dir = os.path.join(temp_dir, "ali")
    os.makedirs(feat_dir)
    os.makedirs(ali_dir)
    # 'a' has feats only, 'b' has both, 'c' has ali only
    torch.save(torch.rand(3, 3), os.path.join(feat_dir, "a.pt"))
    torch.save(torch.rand(4, 3), os.path.join(feat_dir, "b.pt"))
    torch.save(torch.randint(10, (4,), dtype=torch.long), os.path.join(ali_dir, "b.pt"))
    torch.save(torch.randint(10, (5,), dtype=torch.long), os.path.join(ali_dir, "c.pt"))
    # silent mode: incomplete utterances are dropped without warning
    data_set = data.SpectDataSet(temp_dir, warn_on_missing=False)
    assert data_set.has_ali
    assert data_set.utt_ids == ("b",)
    # default mode: one warning per missing counterpart file
    with pytest.warns(UserWarning) as caught:
        data_set = data.SpectDataSet(temp_dir)
    assert len(caught) == 2
    messages = {str(w.message) for w in caught}
    assert "Missing ali for uttid: 'a'" in messages
    assert "Missing feat for uttid: 'c'" in messages
def test_spect_data_write_pdf(temp_dir, device):
    """write_pdf stores tensors as float on CPU, accepts an integer index for the utterance, and honours pdfs_dir."""
    torch.manual_seed(1)
    feat_dir = os.path.join(temp_dir, "feat")
    os.makedirs(feat_dir)
    torch.save(torch.rand(3, 3), os.path.join(feat_dir, "a.pt"))
    data_set = data.SpectDataSet(temp_dir)
    pdf = torch.randint(10, (4, 5), dtype=torch.long)
    # CUDA input must still land on disk as a CPU float tensor
    data_set.write_pdf("b", pdf.cuda() if device == "cuda" else pdf)
    reloaded = torch.load(os.path.join(temp_dir, "pdfs", "b.pt"))
    assert isinstance(reloaded, torch.FloatTensor)
    assert torch.allclose(reloaded, pdf.float())
    # integer id 0 resolves to the first utterance ('a') — a.pt appears
    data_set.write_pdf(0, torch.rand(10, 4))
    assert os.path.exists(os.path.join(temp_dir, "pdfs", "a.pt"))
    # explicit pdfs_dir overrides the default output directory
    data_set.write_pdf("c", pdf, pdfs_dir=os.path.join(temp_dir, "foop"))
    assert os.path.exists(os.path.join(temp_dir, "foop", "c.pt"))
def test_spect_data_write_hyp(temp_dir, device, sos, eos):
    """write_hyp converts to long on CPU and strips sos/eos padding; also supports index ids and hyp_dir."""
    torch.manual_seed(1)
    feat_dir = os.path.join(temp_dir, "feat")
    os.makedirs(feat_dir)
    torch.save(torch.rand(3, 3), os.path.join(feat_dir, "a.pt"))
    data_set = data.SpectDataSet(temp_dir, sos=sos, eos=eos)
    hyp = torch.randint(10, (4, 3), dtype=torch.float)
    # pad with sos/eos blocks when configured; write_hyp should strip them back off
    padded = hyp
    if sos:
        padded = torch.cat([torch.full_like(padded, sos), padded])
    if eos:
        padded = torch.cat([padded, torch.full_like(hyp, eos)])
    data_set.write_hyp("b", padded.cuda() if device == "cuda" else padded)
    reloaded = torch.load(os.path.join(temp_dir, "hyp", "b.pt"))
    assert isinstance(reloaded, torch.LongTensor)
    assert torch.all(reloaded == hyp.long())
    # integer id 0 resolves to the first utterance ('a') — a.pt appears
    data_set.write_hyp(0, torch.randint(10, (11, 3)))
    assert os.path.exists(os.path.join(temp_dir, "hyp", "a.pt"))
    # explicit hyp_dir overrides the default output directory
    data_set.write_hyp("c", hyp, hyp_dir=os.path.join(temp_dir, "foop"))
    assert os.path.exists(os.path.join(temp_dir, "foop", "c.pt"))
def test_valid_spect_data_set(temp_dir, num_utts, file_prefix, populate_torch_dir, sos, eos, feat_dtype):
    """A populated directory round-trips through SpectDataSet.

    Checks, in order: a feats-only data set (ali/ref come back as None),
    a full feats+ali+ref data set (with sos/eos framing applied to the
    expected refs when configured), and a subset_ids-restricted data set.
    Fixtures supply num_utts/file_prefix/sos/eos/feat_dtype combinations.
    """
    feats, _, _, _, _, utt_ids = populate_torch_dir(
        temp_dir,
        num_utts,
        file_prefix=file_prefix,
        include_ali=False,
        include_ref=False,
        feat_dtype=feat_dtype,
    )
    # note that this'll just resave the same features if there's no file
    # prefix. If there is, these ought to be ignored by the data set
    populate_torch_dir(
        temp_dir, num_utts, include_ali=False, include_ref=False,
        feat_dtype=feat_dtype)
    # a nested directory under feat/ should also be invisible to the data set
    if not os.path.isdir(os.path.join(temp_dir, "feat", "fake")):
        os.makedirs(os.path.join(temp_dir, "feat", "fake"))
    torch.save(
        torch.randint(100, (10, 5), dtype=feat_dtype),
        os.path.join(temp_dir, "feat", "fake", file_prefix + "fake.pt"),
    )
    # feats-only: ali and ref entries are None
    data_set = data.SpectDataSet(temp_dir, file_prefix=file_prefix, eos=eos)
    assert not data_set.has_ali and not data_set.has_ref
    assert len(utt_ids) == len(data_set.utt_ids)
    assert all(
        utt_a == utt_b
        for (utt_a, utt_b) in zip(utt_ids, data_set.utt_ids))
    assert all(
        ali_b is None and ref_b is None and torch.allclose(feat_a, feat_b)
        for (feat_a, (feat_b, ali_b, ref_b)) in zip(feats, data_set))
    # now populate ali and ref too
    feats, alis, refs, _, _, utt_ids = populate_torch_dir(
        temp_dir, num_utts, file_prefix=file_prefix, feat_dtype=feat_dtype)
    # build the expected refs: a (1, 3) sos/eos row with -1 boundary columns
    # prepended/appended when sos/eos are configured
    if sos is not None:
        sos_sym = torch.full((3, ), -1, dtype=torch.long)
        sos_sym[0] = sos
        sos_sym = sos_sym.unsqueeze(0)
        refs = [torch.cat([sos_sym, x]) for x in refs]
    if eos is not None:
        eos_sym = torch.full((3, ), -1, dtype=torch.long)
        eos_sym[0] = eos
        eos_sym = eos_sym.unsqueeze(0)
        refs = [torch.cat([x, eos_sym]) for x in refs]
    data_set = data.SpectDataSet(
        temp_dir, file_prefix=file_prefix, sos=sos, eos=eos)
    assert data_set.has_ali and data_set.has_ref
    assert len(utt_ids) == len(data_set.utt_ids)
    assert all(
        utt_a == utt_b
        for (utt_a, utt_b) in zip(utt_ids, data_set.utt_ids))
    assert all(
        torch.all(ali_a == ali_b) and torch.all(ref_a == ref_b)
        and feat_a.dtype == feat_b.dtype and torch.allclose(feat_a, feat_b)
        for ((feat_a, ali_a, ref_a),
             (feat_b, ali_b, ref_b)) in zip(zip(feats, alis, refs), data_set))
    # restricting to subset_ids keeps only (and exactly) the first half
    subset_ids = data_set.utt_ids[:num_utts // 2]
    data_set = data.SpectDataSet(
        temp_dir, file_prefix=file_prefix, subset_ids=set(subset_ids),
        sos=sos, eos=eos)
    assert all(
        utt_a == utt_b
        for (utt_a, utt_b) in zip(subset_ids, data_set.utt_ids))
    assert all(
        torch.all(ali_a == ali_b) and torch.all(
            ref_a == ref_b) and torch.allclose(feat_a, feat_b)
        for ((feat_a, ali_a, ref_a), (feat_b, ali_b, ref_b)) in zip(
            zip(feats[:num_utts // 2], alis[:num_utts // 2],
                refs[:num_utts // 2]),
            data_set,
        ))
def test_spect_data_set_validity(temp_dir, eos):
    """validate_spect_data_set rejects each kind of malformed tensor.

    Builds a valid two-utterance directory, then corrupts utterance 'b'
    one way at a time (wrong dtype, wrong rank, mismatched sizes, invalid
    ref boundaries), asserting the matching ValueError each time and,
    where the fixable path exists, that validate(..., True) warns and
    repairs the file. Order matters: each step restores or overwrites the
    previous corruption.
    """
    torch.manual_seed(1)
    feat_dir = os.path.join(temp_dir, "feat")
    ali_dir = os.path.join(temp_dir, "ali")
    ref_dir = os.path.join(temp_dir, "ref")
    feats_a_pt = os.path.join(feat_dir, "a.pt")
    feats_b_pt = os.path.join(feat_dir, "b.pt")
    ali_a_pt = os.path.join(ali_dir, "a.pt")
    ali_b_pt = os.path.join(ali_dir, "b.pt")
    ref_a_pt = os.path.join(ref_dir, "a.pt")
    ref_b_pt = os.path.join(ref_dir, "b.pt")
    os.makedirs(feat_dir)
    os.makedirs(ali_dir)
    os.makedirs(ref_dir)
    # valid baseline: matching lengths, long alis, (R, 3) refs
    torch.save(torch.rand(10, 4), feats_a_pt)
    torch.save(torch.rand(4, 4), feats_b_pt)
    torch.save(torch.randint(10, (10, ), dtype=torch.long), ali_a_pt)
    torch.save(torch.randint(10, (4, ), dtype=torch.long), ali_b_pt)
    torch.save(
        torch.cat(
            [
                torch.randint(10, (11, 1), dtype=torch.long),
                torch.full((11, 2), -1, dtype=torch.long),
            ],
            -1,
        ),
        ref_a_pt,
    )
    torch.save(torch.tensor([[0, 3, 4], [1, 1, 2]]), ref_b_pt)
    data_set = data.SpectDataSet(temp_dir, eos=eos)
    data.validate_spect_data_set(data_set)
    # --- feature corruption ---
    # long feats clash with float feats elsewhere in the set
    torch.save(torch.rand(4, 4).long(), feats_b_pt)
    with pytest.raises(ValueError, match="not the same tensor type"):
        data.validate_spect_data_set(data_set)
    # 1D feats are rejected
    torch.save(
        torch.rand(4, ),
        feats_b_pt,
    )
    with pytest.raises(ValueError, match="does not have two dimensions"):
        data.validate_spect_data_set(data_set)
    # feature dimension must match the rest of the set (4, not 3)
    torch.save(torch.rand(4, 3), feats_b_pt)
    with pytest.raises(ValueError, match="has second dimension of size 3.*"):
        data.validate_spect_data_set(data_set)
    torch.save(torch.rand(4, 4), feats_b_pt)
    data.validate_spect_data_set(data_set)
    # --- alignment corruption ---
    # int alis are invalid but fixable (cast to long)
    torch.save(torch.randint(10, (4, )).int(), ali_b_pt)
    with pytest.raises(ValueError, match="is not a long tensor"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # will fix bad type
    data.validate_spect_data_set(data_set)  # fine after correction
    # 2D alis are rejected
    torch.save(torch.randint(10, (4, 1), dtype=torch.long), ali_b_pt)
    with pytest.raises(ValueError, match="does not have one dimension"):
        data.validate_spect_data_set(data_set)
    # ali length must equal the number of feature frames
    torch.save(torch.randint(10, (3, ), dtype=torch.long), ali_b_pt)
    with pytest.raises(ValueError, match="does not have the same first"):
        data.validate_spect_data_set(data_set)
    torch.save(torch.randint(10, (4, ), dtype=torch.long), ali_b_pt)
    data.validate_spect_data_set(data_set)
    # --- reference corruption ---
    # int refs are invalid but fixable (cast to long)
    torch.save(torch.Tensor([[0, 1, 2]]).int(), ref_b_pt)
    with pytest.raises(ValueError, match="is not a long tensor"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # convert to long
    data.validate_spect_data_set(data_set)
    # start -1 with end set is an invalid boundary pair, fixable
    torch.save(torch.tensor([[0, -1, 2], [1, 1, 2]]), ref_b_pt)
    with pytest.raises(ValueError, match="invalid boundaries"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # will remove end bound
    data.validate_spect_data_set(data_set)
    # end frame 5 exceeds the 4 feature frames, fixable by trimming
    torch.save(torch.tensor([[0, 0, 1], [1, 3, 5]]), ref_b_pt)
    with pytest.raises(ValueError, match="invalid boundaries"):
        data.validate_spect_data_set(data_set)
    with pytest.warns(UserWarning):
        data.validate_spect_data_set(data_set, True)  # will trim 5 to 4
    data.validate_spect_data_set(data_set)
    # trimming would collapse the segment (start == end), so not fixable
    torch.save(torch.tensor([[0, 0, 1], [1, 4, 5]]), ref_b_pt)
    with pytest.raises(ValueError, match="invalid boundaries"):
        data.validate_spect_data_set(
            data_set, True)  # will not trim b/c causes s == e
    # 1D and 2D refs cannot be mixed within one data set
    torch.save(torch.tensor([1, 2, 3]), ref_b_pt)
    with pytest.raises(ValueError, match="were 2D"):
        data.validate_spect_data_set(data_set)
    # all-1D refs are fine
    torch.save(torch.tensor([10, 4, 2, 5]), ref_a_pt)
    data.validate_spect_data_set(data_set)