def test_frozen_inference(self):
    """Two forward passes of a frozen BayesianLSTM on the same input must match.

    Only the first element of each returned value is compared.
    """
    b_lstm = BayesianLSTM(1, 10)
    b_lstm.freeze = True

    in_data = torch.ones((10, 10, 1))
    b_inference_1 = b_lstm(in_data, hidden_states=None)
    b_inference_2 = b_lstm(in_data, hidden_states=None)

    # torch.equal checks shape and values together and returns a plain bool.
    # An elementwise `==` followed by `.all()` would broadcast, so outputs of
    # mismatched (but broadcast-compatible) shapes could wrongly pass.
    self.assertTrue(
        torch.equal(b_inference_1[0], b_inference_2[0]),
        "frozen BayesianLSTM produced different outputs for identical input",
    )
def test_peephole_inference(self):
    """Smoke test: a peephole BayesianLSTM runs forward before and after freezing.

    No output is inspected; the test passes as long as neither call raises.
    """
    model = BayesianLSTM(1, 10, peephole=True)
    batch = torch.ones((10, 10, 1))

    # Forward pass with the model exactly as constructed.
    model(batch)

    # Forward pass again once the `freeze` flag has been switched on.
    model.freeze = True
    model(batch)