def test_sanity(self):
    raw_logits = torch.tensor([[0.0, 1.0, 2.0]])
    action_space = gym.spaces.Discrete(3)
    categorical = get_action_distribution(action_space, raw_logits)

    log_probs = categorical.log_prob(torch.tensor([0, 1, 2]))
    probs = torch.exp(log_probs)

    # softmax([0, 1, 2]) written out to full precision
    expected_probs = np.array([0.09003057317038046, 0.24472847105479764, 0.6652409557748219])
    self.assertTrue(np.allclose(probs.numpy(), expected_probs))

    # a tuple of two identical Discrete(3) spaces: logits are concatenated,
    # and the joint probability should factorize over the sub-spaces
    tuple_space = gym.spaces.Tuple([action_space, action_space])
    raw_logits = torch.tensor([[0.0, 1.0, 2.0, 0.0, 1.0, 2.0]])
    tuple_distr = get_action_distribution(tuple_space, raw_logits)

    for a1 in [0, 1, 2]:
        for a2 in [0, 1, 2]:
            action = torch.tensor([[a1, a2]])
            log_prob = tuple_distr.log_prob(action)
            probability = torch.exp(log_prob)[0].item()
            self.assertAlmostEqual(probability, expected_probs[a1] * expected_probs[a2], delta=1e-6)
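# Illustration only (not part of the test above): the hard-coded expected_probs
# are exactly softmax([0, 1, 2]), which can be verified with numpy alone.
import numpy as np

_logits = np.array([0.0, 1.0, 2.0])
_softmax = np.exp(_logits) / np.exp(_logits).sum()
# -> [0.09003057 0.24472847 0.66524096], matching expected_probs above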
def test_gumbel_trick(self):
    """
    We use Gumbel noise for sampling, which seems to be faster than torch.multinomial.
    Here we test that the two sampling methods are actually equivalent.
    """
    timing = Timing()

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    with torch.no_grad():
        action_space = gym.spaces.Discrete(8)
        num_logits = calc_num_logits(action_space)
        device_type = 'cpu'
        device = torch.device(device_type)
        logits = torch.rand(self.batch_size, num_logits, device=device) * 10.0 - 5.0

        if device_type == 'cuda':
            torch.cuda.synchronize(device)

        count_gumbel, count_multinomial = np.zeros([action_space.n]), np.zeros([action_space.n])

        # estimate probability mass by actually sampling both ways
        num_samples = 20000

        # warm-up calls so one-time initialization costs don't skew the timing
        action_distribution = get_action_distribution(action_space, logits)
        sample_actions_log_probs(action_distribution)
        action_distribution.sample_gumbel()

        with timing.add_time('gumbel'):
            for i in range(num_samples):
                action_distribution = get_action_distribution(action_space, logits)
                samples_gumbel = action_distribution.sample_gumbel()
                count_gumbel[samples_gumbel[0]] += 1

        action_distribution = get_action_distribution(action_space, logits)
        action_distribution.sample()

        with timing.add_time('multinomial'):
            for i in range(num_samples):
                action_distribution = get_action_distribution(action_space, logits)
                samples_multinomial = action_distribution.sample()
                count_multinomial[samples_multinomial[0]] += 1

        estimated_probs_gumbel = count_gumbel / float(num_samples)
        estimated_probs_multinomial = count_multinomial / float(num_samples)

        log.debug('Gumbel estimated probs: %r', estimated_probs_gumbel)
        log.debug('Multinomial estimated probs: %r', estimated_probs_multinomial)
        log.debug('Sampling timing: %s', timing)
        time.sleep(0.1)  # give the logger a moment to flush
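# A minimal self-contained sketch of the Gumbel-max trick verified above
# (illustration only, independent of the library's sample_gumbel):
# argmax(logits + Gumbel(0, 1) noise) has the same distribution as a sample
# from Categorical(softmax(logits)).
import torch

_logits = torch.tensor([0.0, 1.0, 2.0])
_n = 100000

# Gumbel(0, 1) noise via inverse transform sampling: -log(-log(U)), U ~ Uniform(0, 1)
_u = torch.rand(_n, 3).clamp(min=1e-20)
_gumbel_samples = torch.argmax(_logits - torch.log(-torch.log(_u)), dim=1)
_multinomial_samples = torch.distributions.Categorical(logits=_logits).sample((_n,))

# both empirical distributions should be close to softmax(logits)
print(torch.bincount(_gumbel_samples, minlength=3) / _n)
print(torch.bincount(_multinomial_samples, minlength=3) / _n)
print(torch.softmax(_logits, dim=0))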
def forward(self, actor_core_output):
    """Just forward the FC layer and generate the distribution object."""
    action_distribution_params = self.distribution_linear(actor_core_output)
    action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params)
    return action_distribution_params, action_distribution
def forward(self, actor_core_output):
    action_means = self.distribution_linear(actor_core_output)

    batch_size = action_means.shape[0]
    # the stddev is state-independent: a single learned vector tiled across the batch
    action_stddevs = self.learned_stddev.repeat(batch_size, 1)
    action_distribution_params = torch.cat((action_means, action_stddevs), dim=1)

    action_distribution = get_action_distribution(self.action_space, raw_logits=action_distribution_params)
    return action_distribution_params, action_distribution
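# A minimal sketch of how concatenated (means, stddevs) parameters map to a diagonal
# Gaussian. Assumption: for continuous action spaces get_action_distribution splits the
# parameter vector in half, means first; build_normal below is a hypothetical helper,
# not the library's implementation.
import torch

def build_normal(params: torch.Tensor) -> torch.distributions.Normal:
    num_actions = params.shape[1] // 2
    means, stddevs = params[:, :num_actions], params[:, num_actions:]
    return torch.distributions.Normal(means, stddevs)

_params = torch.cat((torch.zeros(2, 3), torch.ones(2, 3)), dim=1)  # means=0, stddevs=1
_actions = build_normal(_params).sample()  # shape [2, 3]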
def forward(self, actor_core_output):
    """Just forward the FC layer and generate the distribution object for all options."""
    action_distribution_params = self.distribution_linear(actor_core_output).view(
        -1, self.num_action_outputs, self.num_options)
    action_distribution = get_action_distribution(
        self.action_space, raw_logits=action_distribution_params, num_options=self.num_options)
    return action_distribution_params, action_distribution
def test_tuple_sanity_check(self):
    num_spaces, num_actions = 3, 2
    simple_space = gym.spaces.Discrete(num_actions)
    spaces = [simple_space for _ in range(num_spaces)]
    tuple_space = gym.spaces.Tuple(spaces)

    self.assertEqual(calc_num_logits(tuple_space), num_spaces * num_actions)

    simple_logits = torch.zeros(1, num_actions)
    tuple_logits = torch.zeros(1, calc_num_logits(tuple_space))

    simple_distr = get_action_distribution(simple_space, simple_logits)
    tuple_distr = get_action_distribution(tuple_space, tuple_logits)

    # entropy and log-prob of a tuple of independent spaces are sums over the sub-spaces
    tuple_entropy = tuple_distr.entropy()
    self.assertEqual(tuple_entropy, simple_distr.entropy() * num_spaces)

    simple_logprob = simple_distr.log_prob(torch.ones(1))
    tuple_logprob = tuple_distr.log_prob(torch.ones(1, num_spaces))
    self.assertEqual(tuple_logprob, simple_logprob * num_spaces)
def test_simple_distribution(self):
    simple_action_space = gym.spaces.Discrete(3)
    simple_num_logits = calc_num_logits(simple_action_space)
    self.assertEqual(simple_num_logits, simple_action_space.n)

    simple_logits = torch.rand(self.batch_size, simple_num_logits)
    simple_action_distribution = get_action_distribution(simple_action_space, simple_logits)

    simple_actions = simple_action_distribution.sample()
    self.assertEqual(list(simple_actions.shape), [self.batch_size])
    self.assertTrue(all(0 <= a < simple_action_space.n for a in simple_actions))
def test_tuple_distribution(self):
    num_spaces = random.randint(1, 4)
    spaces = [gym.spaces.Discrete(random.randint(2, 5)) for _ in range(num_spaces)]
    action_space = gym.spaces.Tuple(spaces)

    num_logits = calc_num_logits(action_space)
    logits = torch.rand(self.batch_size, num_logits)
    self.assertEqual(num_logits, sum(s.n for s in action_space.spaces))

    action_distribution = get_action_distribution(action_space, logits)

    tuple_actions = action_distribution.sample()
    self.assertEqual(list(tuple_actions.shape), [self.batch_size, num_spaces])

    log_probs = action_distribution.log_prob(tuple_actions)
    self.assertEqual(list(log_probs.shape), [self.batch_size])

    entropy = action_distribution.entropy()
    self.assertEqual(list(entropy.shape), [self.batch_size])
def _record_summaries(self, train_loop_vars):
    var = train_loop_vars
    self.last_summary_time = time.time()

    stats = AttrDict()

    grad_norm = sum(
        p.grad.data.norm(2).item() ** 2
        for p in self.actor_critic.parameters()
        if p.grad is not None
    ) ** 0.5
    stats.grad_norm = grad_norm
    stats.loss = var.loss
    stats.value = var.result.values.mean()
    stats.entropy = var.action_distribution.entropy().mean()
    stats.policy_loss = var.policy_loss
    stats.value_loss = var.value_loss
    stats.entropy_loss = var.entropy_loss
    stats.adv_min = var.adv.min()
    stats.adv_max = var.adv.max()
    stats.adv_std = var.adv_std
    stats.max_abs_logprob = torch.abs(var.mb.action_logits).max()

    if hasattr(var.action_distribution, 'summaries'):
        stats.update(var.action_distribution.summaries())

    if var.epoch == self.cfg.ppo_epochs - 1 and var.batch_num == len(var.minibatches) - 1:
        # we collect these stats only for the last PPO batch, or every time if we're only
        # doing one batch, IMPALA-style
        ratio_mean = torch.abs(1.0 - var.ratio).mean().detach()
        ratio_min = var.ratio.min().detach()
        ratio_max = var.ratio.max().detach()
        # log.debug('Learner %d ratio mean min max %.4f %.4f %.4f', self.policy_id, ratio_mean.cpu().item(), ratio_min.cpu().item(), ratio_max.cpu().item())

        value_delta = torch.abs(var.values - var.old_values)
        value_delta_avg, value_delta_max = value_delta.mean(), value_delta.max()

        # calculate KL-divergence with the behaviour policy action distribution
        old_action_distribution = get_action_distribution(
            self.actor_critic.action_space, var.mb.action_logits,
        )
        kl_old = var.action_distribution.kl_divergence(old_action_distribution)
        kl_old_mean = kl_old.mean()

        stats.kl_divergence = kl_old_mean
        stats.value_delta = value_delta_avg
        stats.value_delta_max = value_delta_max
        stats.fraction_clipped = (
            (var.ratio < var.clip_ratio_low).float() + (var.ratio > var.clip_ratio_high).float()
        ).mean()
        stats.ratio_mean = ratio_mean
        stats.ratio_min = ratio_min
        stats.ratio_max = ratio_max
        stats.num_sgd_steps = var.num_sgd_steps

    # this caused numerical issues on some versions of PyTorch with second moment reaching infinity
    adam_max_second_moment = 0.0
    for key, tensor_state in self.optimizer.state.items():
        adam_max_second_moment = max(tensor_state['exp_avg_sq'].max().item(), adam_max_second_moment)
    stats.adam_max_second_moment = adam_max_second_moment

    version_diff = var.curr_policy_version - var.mb.policy_version
    stats.version_diff_avg = version_diff.mean()
    stats.version_diff_min = version_diff.min()
    stats.version_diff_max = version_diff.max()

    for key, value in stats.items():
        stats[key] = to_scalar(value)

    return stats
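# A minimal sketch (illustration only, not the library's kl_divergence implementation)
# of the KL statistic recorded above for categorical policies, assuming
# kl_divergence(other) computes KL(self || other):
# KL(p_new || p_old) = sum_a p_new(a) * (log p_new(a) - log p_old(a))
import torch

def categorical_kl(new_logits: torch.Tensor, old_logits: torch.Tensor) -> torch.Tensor:
    new_log_probs = torch.log_softmax(new_logits, dim=-1)
    old_log_probs = torch.log_softmax(old_logits, dim=-1)
    return (new_log_probs.exp() * (new_log_probs - old_log_probs)).sum(dim=-1)

_kl = categorical_kl(torch.randn(4, 3), torch.randn(4, 3))  # one KL value per batch element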