def test_by_batch_restore_gaussian(self):
    """
    Test Gaussian HMM with spherical covariance by restoring an existing HMM
    via batch training.

    A reference HMM with hand-set parameters generates samples; a freshly
    constructed HMM with the same config is fit on those samples in chunks,
    and its parameters must match the reference within fixed absolute
    tolerances after aligning hidden-state labels.
    """
    config = {
        'num_states': 3,
        'output_dim': 2,
        'output_covariance': 'spherical'
    }

    def _assert_close(actual, expected, atol):
        # assertTrue(allclose) on its own only reports "False is not true";
        # attach the worst deviation so a failure is diagnosable.
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(
            torch.allclose(actual, expected, atol=atol, rtol=0),
            msg='max abs deviation {:.4f} exceeds atol={}'.format(
                max_diff, atol
            )
        )

    hmm_reference = HMM(**config)
    hmm_reference.markov.initial_probs.set_(torch.as_tensor(
        [0.75, 0.25, 0], dtype=torch.float
    ))
    hmm_reference.markov.transition_probs.set_(torch.as_tensor([
        [0, 1, 0],
        [0.5, 0, 0.5],
        [1, 0, 0]
    ], dtype=torch.float))
    hmm_reference.emission.means.set_(torch.as_tensor([
        [-1, -1],
        [0, 1],
        [1, -1]
    ], dtype=torch.float))
    # Explicit dtype, matching every other parameter set above.
    hmm_reference.emission.covars.set_(torch.as_tensor(
        [0.1, 0.25, 0.2], dtype=torch.float
    ))

    torch.manual_seed(42)
    sequences = hmm_reference.sample(8192, 8)

    hmm = HMM(**config)
    # chunk(32) splits the samples into 32 pieces along dim 0; presumably
    # fit() consumes them as mini-batches — this is the "batch" in the name.
    hmm.fit(sequences.chunk(32))

    # Learned state labels are arbitrary. The reference's initial
    # probabilities are strictly decreasing, so sorting the learned states by
    # descending initial probability maps them onto the reference's order.
    order = hmm.markov.initial_probs.argsort(descending=True)

    _assert_close(
        hmm.markov.initial_probs[order],
        hmm_reference.markov.initial_probs,
        atol=0.01
    )
    # Permute both rows and columns so transitions use the aligned labels.
    _assert_close(
        hmm.markov.transition_probs[order][:, order],
        hmm_reference.markov.transition_probs,
        atol=0.01
    )
    _assert_close(
        hmm.emission.means[order],
        hmm_reference.emission.means,
        atol=0.05
    )
    _assert_close(
        hmm.emission.covars[order],
        hmm_reference.emission.covars,
        atol=0.02
    )
def test_by_restore_discrete(self):
    """
    Test Discrete HMM by restoring an existing HMM.

    A reference HMM with hand-set parameters and discrete outputs generates
    samples; a freshly constructed HMM with the same config is fit on them,
    and its parameters must match the reference within fixed absolute
    tolerances after aligning hidden-state labels.
    """
    config = {
        'num_states': 3,
        'output': 'discrete',
        'output_num_states': 4
    }

    def _assert_close(actual, expected, atol):
        # assertTrue(allclose) on its own only reports "False is not true";
        # attach the worst deviation so a failure is diagnosable.
        max_diff = (actual - expected).abs().max().item()
        self.assertTrue(
            torch.allclose(actual, expected, atol=atol, rtol=0),
            msg='max abs deviation {:.4f} exceeds atol={}'.format(
                max_diff, atol
            )
        )

    hmm_reference = HMM(**config)
    hmm_reference.markov.initial_probs.set_(torch.as_tensor(
        [0.75, 0.25, 0], dtype=torch.float
    ))
    hmm_reference.markov.transition_probs.set_(torch.as_tensor([
        [0.9, 0.1, 0],
        [0.5, 0, 0.5],
        [0, 1, 0]
    ], dtype=torch.float))
    # Explicit dtype, matching every other parameter set above.
    hmm_reference.emission.probabilities.set_(torch.as_tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0.5, 0.5]
    ], dtype=torch.float))

    torch.manual_seed(21)
    sequences = hmm_reference.sample(16536, 16)

    hmm = HMM(**config)
    hmm.fit(sequences, epochs=100, eps=1e-7, patience=3)

    # Learned state labels are arbitrary. The reference's initial
    # probabilities are strictly decreasing, so sorting the learned states by
    # descending initial probability maps them onto the reference's order.
    order = hmm.markov.initial_probs.argsort(descending=True)

    _assert_close(
        hmm.markov.initial_probs[order],
        hmm_reference.markov.initial_probs,
        atol=0.01
    )
    # Permute both rows and columns so transitions use the aligned labels.
    _assert_close(
        hmm.markov.transition_probs[order][:, order],
        hmm_reference.markov.transition_probs,
        atol=0.01
    )
    _assert_close(
        hmm.emission.probabilities[order],
        hmm_reference.emission.probabilities,
        atol=0.05
    )