def aggregate_distributions_thompson(self, distrs):
    # Ensure a batch dimension: the helper functions expect [batch, actions, bins].
    if len(distrs.size()) == 2:
        distrs_ = distrs.unsqueeze(0)
    else:
        distrs_ = distrs
    # Mix the per-action return distributions, weighting each action by its
    # Thompson probability (the probability of being the greedy action).
    thompson_probs = thompson_probabilities(distrs_)
    aggregated_distrs = aggregate_distributions_by_policy(
        distrs_, thompson_probs)
    return aggregated_distrs.squeeze(0)
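The helpers `thompson_probabilities` and `aggregate_distributions_by_policy` are not shown in this snippet. Below is a minimal, hypothetical sketch of behavior consistent with the call sites above (the names with a `_sketch` suffix, the `bin_mids` argument, and the Monte Carlo sample count are assumptions, not the repository's actual implementation):

import torch

def thompson_probabilities_sketch(distrs, bin_mids, n_samples=1000, mask=None):
    # distrs: [batch, n_actions, n_bins] categorical probabilities over bin_mids.
    # Estimates, per batch element, the probability that each action yields the
    # highest sampled return (i.e. the Thompson / posterior-greedy probability).
    batch, n_actions, n_bins = distrs.shape
    flat = distrs.reshape(batch * n_actions, n_bins)
    idx = torch.multinomial(flat, n_samples, replacement=True)        # sampled bin indices
    returns = bin_mids[idx].reshape(batch, n_actions, n_samples)      # sampled returns
    if mask is not None:
        # Invalid actions can never win the argmax.
        returns = returns.masked_fill(mask.unsqueeze(-1) == 0, float("-inf"))
    best = returns.argmax(dim=1)                                      # [batch, n_samples]
    counts = torch.zeros(batch, n_actions).scatter_add_(
        1, best, torch.ones_like(best, dtype=torch.float32))
    return counts / n_samples

def aggregate_distributions_by_policy_sketch(distrs, policy_probs):
    # Mixture of per-action return distributions, weighted by the policy.
    return (distrs * policy_probs.unsqueeze(-1)).sum(dim=1)           # [batch, n_bins]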
def select_action(self, state, greedy=False):
    # Forward pass: per-action categorical return distributions for this state.
    qdist = self(torch.tensor(state, dtype=torch.float32))
    if greedy:
        # Greedy: pick the action with the highest expected return
        # (expectation of each categorical distribution over the bin midpoints).
        exp_values = (qdist * self.distr_calc.bin_mids).sum(1)
        return torch.argmax(exp_values)
    else:
        # Thompson sampling: sample an action in proportion to the probability
        # that it is the greedy action under the current return distributions.
        thompson_probs = thompson_probabilities(
            qdist.unsqueeze(0)).squeeze(0)
        return torch.multinomial(thompson_probs, 1)
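A toy illustration of the greedy branch, with made-up numbers (not taken from the snippet): the expected return per action is the dot product of each categorical distribution with the bin midpoints, and the greedy action maximizes it.

import torch
bin_mids = torch.tensor([0.0, 1.0, 2.0])
qdist = torch.tensor([[0.7, 0.2, 0.1],    # action 0: mostly low returns
                      [0.1, 0.3, 0.6]])   # action 1: mostly high returns
exp_values = (qdist * bin_mids).sum(1)    # tensor([0.4, 1.5])
greedy_action = torch.argmax(exp_values)  # tensor(1)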
def action_distrs_to_action_prob_logits(self, action_distrs, mask):
    # Log Thompson probabilities per action; the small constant avoids log(0).
    action_p_logits = torch.log(thompson_probabilities(action_distrs, mask) + 1e-5)
    # Push masked-out (invalid) actions to a very low logit.
    action_p_logits[mask == 0] = -1e5
    return action_p_logits
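For illustration only (hypothetical logit values, not from the snippet): feeding the masked logits to a categorical distribution concentrates essentially all probability mass on the valid actions.

import torch
logits = torch.tensor([[-0.2, -1e5, -1.7]])   # action 1 masked out
probs = torch.softmax(logits, dim=1)          # ~zero probability for action 1
action = torch.distributions.Categorical(logits=logits).sample()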