Example #1
0
    def test_frozen_inference(self):
        """Check that a frozen BayesianLSTM is deterministic across calls.

        Sets ``freeze`` on the layer, runs two forward passes over the same
        input with ``hidden_states=None``, and requires the first element of
        each call's result to be identical (same shape and same values).
        """
        b_lstm = BayesianLSTM(1, 10)
        b_lstm.freeze = True

        # NOTE(review): shape presumably (batch, seq_len, in_features) to match
        # BayesianLSTM(1, 10)'s first argument — confirm against the layer.
        in_data = torch.ones((10, 10, 1))
        b_inference_1 = b_lstm(in_data, hidden_states=None)
        b_inference_2 = b_lstm(in_data, hidden_states=None)

        # torch.equal compares size and elements in one call and returns a
        # plain bool. This avoids broadcasting `==` and avoids relying on the
        # truthiness of a 0-d bool tensor inside assertEqual.
        self.assertTrue(torch.equal(b_inference_1[0], b_inference_2[0]))
Example #2
0
 def test_peephole_inference(self):
     """Run a peephole BayesianLSTM before and after freezing.

     Performs one forward pass before ``freeze`` is set, then sets it and
     requires two frozen forward passes over the same input to produce the
     same first output element.
     """
     b_lstm = BayesianLSTM(1, 10, peephole=True)
     in_data = torch.ones((10, 10, 1))

     # Forward pass before freeze is set; the test only requires it not to raise.
     b_lstm(in_data)

     b_lstm.freeze = True
     frozen_1 = b_lstm(in_data)
     frozen_2 = b_lstm(in_data)

     # Frozen inference is expected to be repeatable for identical input.
     # NOTE(review): assumes the call returns an indexable result whose first
     # item is a tensor — confirm against BayesianLSTM.forward.
     self.assertTrue(torch.equal(frozen_1[0], frozen_2[0]))