def test_bernoulli_kl():
    """Check the output shape of the KL divergence between two distributions.

    Two distributions are built from independent random logits of shape
    (100, 8, 5); the KL divergence is asserted to have shape (100, 8).
    """
    # NOTE(review): despite "bernoulli" in the test name, both operands are
    # CategoricalProbabilityDistribution instances over Discrete(5) -- confirm
    # whether a Bernoulli distribution was intended here.
    batch_shape = (100, 8)
    num_logits = 5

    raw_p = np.random.randn(*batch_shape, num_logits)
    raw_q = np.random.randn(*batch_shape, num_logits)

    dist_p = CategoricalProbabilityDistribution(
        logits=torch.from_numpy(raw_p),
        action_space=spaces.Discrete(num_logits),
        temperature=1.0,
    )
    dist_q = CategoricalProbabilityDistribution(
        logits=torch.from_numpy(raw_q),
        action_space=spaces.Discrete(num_logits),
        temperature=1.0,
    )

    # one KL value per leading (batch) index is expected
    assert dist_p.kl(dist_q).numpy().ndim == len(batch_shape)
    assert dist_p.kl(dist_q).numpy().shape == batch_shape
def test_categorical_entropy():
    """Check the output shape of the categorical entropy for 1-D, 2-D and 3-D logits.

    In every case the trailing axis holds 5 logits; the entropy is asserted to
    have one dimension less than the logits, keeping the leading axes.
    """
    num_actions = 5

    def build(*leading_dims):
        # random logits of shape (*leading_dims, num_actions) over a matching discrete space
        raw_logits = np.random.randn(*leading_dims, num_actions)
        return CategoricalProbabilityDistribution(
            logits=torch.from_numpy(raw_logits),
            action_space=spaces.Discrete(num_actions),
            temperature=1.0,
        )

    # un-batched logits -> scalar entropy
    unbatched = build()
    assert unbatched.entropy().numpy().ndim == 0

    # single batch axis
    batched = build(100)
    assert batched.entropy().numpy().ndim == 1
    assert batched.entropy().numpy().shape == (100,)

    # two leading axes
    nested = build(100, 8)
    assert nested.entropy().numpy().ndim == 2
    assert nested.entropy().numpy().shape == (100, 8)
def __init__(self, logits: torch.Tensor, action_space: spaces.MultiDiscrete, temperature: float):
    """Split the logits into one categorical sub-distribution per entry of ``action_space.nvec``.

    :param logits: Logits tensor; its last axis is cut into consecutive slices, one per sub-action.
    :param action_space: The multi-discrete action space; each ``nvec`` entry gives the slice width
        (and the size of the discrete space) of the corresponding sub-action.
    :param temperature: Temperature forwarded unchanged to every sub-distribution.
    """
    # NOTE(review): presumably the last logits axis has length sum(action_space.nvec);
    # this is not checked here -- confirm against callers.
    self.sub_distributions = []
    start = 0
    for num_categories in action_space.nvec:
        stop = start + num_categories
        self.sub_distributions.append(
            CategoricalProbabilityDistribution(
                logits=logits[..., start:stop],
                action_space=spaces.Discrete(num_categories),
                temperature=temperature,
            )
        )
        # the next sub-distribution starts where this one ended
        start = stop
def get_dict_distribution():
    """Build a two-component DictProbabilityDistribution from random logits.

    The components are:
      - ``"action_0"``: a categorical distribution with logits of shape (100, 3)
      - ``"action_1"``: a Bernoulli distribution with logits of shape (100, 5)

    :return: The DictProbabilityDistribution wrapping both components.
    """
    batch_size = 100
    num_categories = 3
    num_binary = 5

    logits_0 = torch.from_numpy(np.random.randn(batch_size, num_categories))
    logits_1 = torch.from_numpy(np.random.randn(batch_size, num_binary))

    # Each action space is sized from the same constant as its logits so the two
    # cannot drift apart (previously 3 categorical logits were paired with Discrete(9)).
    distribution_dict = {
        "action_0": CategoricalProbabilityDistribution(
            logits=logits_0,
            action_space=spaces.Discrete(num_categories),
            temperature=1.0,
        ),
        "action_1": BernoulliProbabilityDistribution(
            logits=logits_1,
            action_space=spaces.MultiBinary(num_binary),
            temperature=1.0,
        ),
    }
    return DictProbabilityDistribution(distribution_dict=distribution_dict)
def test_categorical_required_logits_shape():
    """Check that a ``Discrete(5)`` action space requires logits of shape ``[5]``."""
    num_actions = 5
    discrete_space = spaces.Discrete(num_actions)
    expected_shape = [num_actions]

    required = CategoricalProbabilityDistribution.required_logits_shape(action_space=discrete_space)
    assert required == expected_shape