class TestMultiOneHotCategorical(unittest.TestCase):
    """Compare MultiOneHotCategorical with per-section OneHotCategorical references.

    The fixture holds a batch of two probability rows made of three
    consecutive sections of sizes (4, 3, 5). Each test checks that the
    multi-section distribution agrees with three independent
    ``OneHotCategorical`` distributions built from the same sections.
    """

    def setUp(self) -> None:
        """Build the probabilities, actions and reference distributions."""
        self.test_probs = torch.tensor([
            [0.3, 0.2, 0.4, 0.1, 0.25, 0.5, 0.25, 0.3, 0.4, 0.1, 0.1, 0.1],
            [0.2, 0.3, 0.1, 0.4, 0.5, 0.3, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1],
        ])
        self.test_sections = (4, 3, 5)
        self.test_actions = torch.tensor([
            [0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
            [0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.],
        ]).long()
        self.test_sected_actions = torch.split(
            self.test_actions, self.test_sections, dim=-1)
        self.test_multi_onehot_categorical = MultiOneHotCategorical(
            self.test_probs, self.test_sections)

        # Derive the per-section probabilities from test_sections instead of
        # hard-coded slice bounds, so the references cannot drift away from
        # the layout handed to MultiOneHotCategorical.
        section1_probs, section2_probs, section3_probs = torch.split(
            self.test_probs, self.test_sections, dim=-1)
        self.test_onehot_categorical1 = OneHotCategorical(section1_probs)
        self.test_onehot_categorical2 = OneHotCategorical(section2_probs)
        self.test_onehot_categorical3 = OneHotCategorical(section3_probs)

    def test_log_prob(self):
        """The joint log-prob equals the sum of the per-section log-probs."""
        test_cat1_log_prob = self.test_onehot_categorical1.log_prob(
            self.test_sected_actions[0])
        test_cat2_log_prob = self.test_onehot_categorical2.log_prob(
            self.test_sected_actions[1])
        test_cat3_log_prob = self.test_onehot_categorical3.log_prob(
            self.test_sected_actions[2])
        test_multi_cat_log_prob = self.test_multi_onehot_categorical.log_prob(
            self.test_actions)
        expected_log_prob = (
            test_cat1_log_prob + test_cat2_log_prob + test_cat3_log_prob)

        self.assertEqual(test_cat1_log_prob.shape, test_multi_cat_log_prob.shape)
        # Floating-point sums may be accumulated in a different order by the
        # implementation under test, so compare with a tolerance rather than
        # bit-for-bit equality.
        self.assertTrue(
            torch.allclose(expected_log_prob, test_multi_cat_log_prob),
            "Expected same log_prob!!!")

    def test_sample(self):
        """A joint sample has the shape and per-row total of the references."""
        test_cat1_sample = self.test_onehot_categorical1.sample()
        test_cat2_sample = self.test_onehot_categorical2.sample()
        test_cat3_sample = self.test_onehot_categorical3.sample()
        test_cat_sample = torch.cat(
            [test_cat1_sample, test_cat2_sample, test_cat3_sample], dim=-1)
        test_multi_cat_sample = self.test_multi_onehot_categorical.sample()

        self.assertEqual(test_cat_sample.shape, test_multi_cat_sample.shape)
        # The samples are random, so only the number of active entries per
        # row is compared; these are exact 0/1 sums, so torch.equal is safe.
        self.assertTrue(
            torch.equal(test_cat_sample.sum(dim=-1),
                        test_multi_cat_sample.sum(dim=-1)))

    def test_entropy(self):
        """The joint entropy equals the sum of the per-section entropies."""
        test_cat1_entropy = self.test_onehot_categorical1.entropy()
        test_cat2_entropy = self.test_onehot_categorical2.entropy()
        test_cat3_entropy = self.test_onehot_categorical3.entropy()
        test_multi_cat_entropy = self.test_multi_onehot_categorical.entropy()
        expected_entropy = (
            test_cat1_entropy + test_cat2_entropy + test_cat3_entropy)

        # torch.allclose broadcasts, so pin the shape explicitly to keep the
        # strictness the previous exact comparison had.
        self.assertEqual(expected_entropy.shape, test_multi_cat_entropy.shape)
        self.assertTrue(
            torch.allclose(expected_entropy, test_multi_cat_entropy),
            "Expected same entropy!!!")
def forward(self, x):
    """Encode ``x``, draw a one-hot sample, and decode that sample.

    ``x`` is reshaped to ``(-1, 784)`` and passed to ``self.encode``; the
    result is used as the probabilities of a ``OneHotCategorical``
    distribution (see https://pytorch.org/docs/stable/distributions.html).

    Returns:
        A tuple ``(decoded, log_prob, entropy)``: ``self.decode`` applied to
        the drawn sample, the log-probability of that sample, and the
        entropy of the distribution.
    """
    # NOTE(review): 784 presumably corresponds to flattened 28x28 inputs;
    # confirm against the encoder definition.
    flattened = x.view(-1, 784)
    distribution = OneHotCategorical(self.encode(flattened))

    # torch.distributions handles sampling and the derived statistics.
    sampled = distribution.sample()
    sample_log_prob = distribution.log_prob(sampled)
    distribution_entropy = distribution.entropy()

    decoded = self.decode(sampled)
    return decoded, sample_log_prob, distribution_entropy