예제 #1
0
    def test_by_batch_restore_gaussian(self):
        """
        Fit a spherical-covariance Gaussian HMM in mini-batches on data sampled
        from a hand-specified reference HMM, then check that the fitted model
        recovers the reference parameters (up to a permutation of the hidden
        states).
        """
        params = dict(num_states=3, output_dim=2, output_covariance='spherical')

        # Reference model whose parameters are overwritten in place below.
        reference = HMM(**params)

        start = torch.as_tensor([0.75, 0.25, 0], dtype=torch.float)
        transitions = torch.as_tensor(
            [[0, 1, 0], [0.5, 0, 0.5], [1, 0, 0]], dtype=torch.float
        )
        centers = torch.as_tensor([[-1, -1], [0, 1], [1, -1]], dtype=torch.float)
        variances = torch.as_tensor([0.1, 0.25, 0.2])

        reference.markov.initial_probs.set_(start)
        reference.markov.transition_probs.set_(transitions)
        reference.emission.means.set_(centers)
        reference.emission.covars.set_(variances)

        # Seed right before sampling. This makes the sampled data and the
        # subsequent initialization of the model being fitted reproducible.
        torch.manual_seed(42)
        samples = reference.sample(8192, 8)

        # Batch training: the samples are split into 32 chunks along dim 0.
        model = HMM(**params)
        model.fit(samples.chunk(32))

        # The learned states come back in arbitrary order. The reference initial
        # probabilities (0.75, 0.25, 0) are strictly decreasing. Sorting the
        # learned ones in descending order therefore maps each learned state
        # onto the reference state with the same index.
        perm = model.markov.initial_probs.argsort(descending=True)

        checks = [
            (model.markov.initial_probs[perm],
             reference.markov.initial_probs, 0.01),
            (model.markov.transition_probs[perm][:, perm],
             reference.markov.transition_probs, 0.01),
            (model.emission.means[perm],
             reference.emission.means, 0.05),
            (model.emission.covars[perm],
             reference.emission.covars, 0.02),
        ]
        for learned, expected, tolerance in checks:
            self.assertTrue(
                torch.allclose(learned, expected, atol=tolerance, rtol=0)
            )
예제 #2
0
    def test_by_restore_discrete(self):
        """
        Test Discrete HMM by restoring an existing HMM.

        Sample sequences from a reference discrete HMM with hand-set parameters,
        fit a freshly constructed HMM on them, and check that the learned
        parameters match the reference (up to a permutation of hidden states).
        """
        config = {
            'num_states': 3,
            'output': 'discrete',
            'output_num_states': 4
        }

        # Reference model whose parameters are overwritten in place below.
        hmm_reference = HMM(**config)

        hmm_reference.markov.initial_probs.set_(torch.as_tensor(
            [0.75, 0.25, 0], dtype=torch.float
        ))
        hmm_reference.markov.transition_probs.set_(torch.as_tensor([
            [0.9, 0.1, 0],
            [0.5, 0, 0.5],
            [0, 1, 0]
        ], dtype=torch.float))

        # Explicit dtype, matching the other parameter tensors in this test.
        hmm_reference.emission.probabilities.set_(torch.as_tensor([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 0.5, 0.5]
        ], dtype=torch.float))

        # Seed right before sampling. This makes the sampled data and the
        # subsequent initialization of the model being fitted reproducible.
        torch.manual_seed(21)
        # NOTE(review): 16536 is close to 2**14 = 16384; confirm it is intended.
        sequences = hmm_reference.sample(16536, 16)

        # Full-batch fit. The meaning of epochs / eps / patience is defined by
        # HMM.fit; presumably an iteration cap plus early-stopping settings.
        hmm = HMM(**config)
        hmm.fit(sequences, epochs=100, eps=1e-7, patience=3)

        # The learned states come back in arbitrary order. The reference initial
        # probabilities (0.75, 0.25, 0) are strictly decreasing. Sorting the
        # learned ones in descending order therefore aligns learned state
        # indices with the reference state indices.
        order = hmm.markov.initial_probs.argsort(descending=True)

        self.assertTrue(torch.allclose(
            hmm.markov.initial_probs[order], hmm_reference.markov.initial_probs,
            atol=0.01, rtol=0
        ))

        # Permute both rows (source state) and columns (target state).
        self.assertTrue(torch.allclose(
            hmm.markov.transition_probs[order][:, order], hmm_reference.markov.transition_probs,
            atol=0.01, rtol=0
        ))

        self.assertTrue(torch.allclose(
            hmm.emission.probabilities[order], hmm_reference.emission.probabilities,
            atol=0.05, rtol=0
        ))