def test_gumbel_trick(self):
        """
        Check that Gumbel-noise sampling agrees with pytorch multinomial sampling.

        We use a Gumbel noise which seems to be faster compared to using pytorch multinomial.
        Both samplers are run `num_samples` times on the same random logits for a
        Discrete(8) action space. The action drawn for the first batch element is
        counted on every iteration, which gives one empirical distribution per sampler.
        The two empirical distributions must agree within a statistical tolerance,
        otherwise the test fails. Timings of both sampling loops are logged at debug level.
        """

        timing = Timing()

        # NOTE(review): these are process-wide flags and they are not restored afterwards,
        # so they stay set for whatever runs after this test.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        with torch.no_grad():
            action_space = gym.spaces.Discrete(8)
            num_logits = calc_num_logits(action_space)
            device_type = 'cpu'
            device = torch.device(device_type)

            # uniform random logits in [-5, 5)
            logits = torch.rand(self.batch_size, num_logits,
                                device=device) * 10.0 - 5.0

            # only relevant when device_type is switched to 'cuda' by hand
            if device_type == 'cuda':
                torch.cuda.synchronize(device)

            count_gumbel = np.zeros([action_space.n])
            count_multinomial = np.zeros([action_space.n])

            # estimate probability mass by actually sampling both ways
            num_samples = 20000

            # untimed calls before the timed section (presumably a warm-up)
            action_distribution = get_action_distribution(action_space, logits)
            sample_actions_log_probs(action_distribution)
            action_distribution.sample_gumbel()

            with timing.add_time('gumbel'):
                for _ in range(num_samples):
                    action_distribution = get_action_distribution(
                        action_space, logits)
                    samples_gumbel = action_distribution.sample_gumbel()
                    # only the first batch element contributes to the histogram
                    count_gumbel[samples_gumbel[0]] += 1

            # untimed call before the timed section (presumably a warm-up)
            action_distribution = get_action_distribution(action_space, logits)
            action_distribution.sample()

            with timing.add_time('multinomial'):
                for _ in range(num_samples):
                    action_distribution = get_action_distribution(
                        action_space, logits)
                    samples_multinomial = action_distribution.sample()
                    # only the first batch element contributes to the histogram
                    count_multinomial[samples_multinomial[0]] += 1

            estimated_probs_gumbel = count_gumbel / float(num_samples)
            estimated_probs_multinomial = count_multinomial / float(
                num_samples)

            log.debug('Gumbel estimated probs: %r', estimated_probs_gumbel)
            log.debug('Multinomial estimated probs: %r',
                      estimated_probs_multinomial)
            log.debug('Sampling timing: %s', timing)
            time.sleep(0.1)  # to finish logging

            # Each estimated probability is a binomial frequency with variance
            # p * (1 - p) / num_samples <= 0.25 / num_samples. The difference of two
            # independent estimates therefore has std <= sqrt(0.5 / num_samples).
            # Six standard deviations keeps the check from being flaky while still
            # catching samplers that draw from different distributions.
            tolerance = 6.0 * np.sqrt(0.5 / num_samples)
            max_abs_diff = np.abs(
                estimated_probs_gumbel - estimated_probs_multinomial).max()
            assert max_abs_diff < tolerance, (
                'Gumbel and multinomial sampling disagree: max abs diff %r >= tolerance %r'
                % (max_abs_diff, tolerance))
Exemplo n.º 2
0
    def forward_tail(self, core_output, with_action_distribution=False):
        """
        Produce termination, value and action outputs from the core output.

        The termination probabilities and a randomly drawn termination mask are stored
        on the instance as `self.termination_prob` and `self.termination_mask`, and are
        also returned as part of the result.

        :param core_output: core output, passed to the termination head, the critic head
            and the action parameterization
        :param with_action_distribution: when truthy, the action distribution object is
            attached to the result as `action_distribution`
        :return: AttrDict with `actions`, `action_logits`, `log_prob_actions`, `values`,
            `termination_prob` and `termination_mask`
        """
        term_prob = self.termination(core_output)
        self.termination_prob = term_prob

        # compare the probabilities against uniform noise of the same shape:
        # the mask is 1 where the probability exceeds the noise and 0 elsewhere
        should_terminate = term_prob > torch.rand_like(term_prob)
        one = torch.ones(1, device=term_prob.device)
        zero = torch.zeros(1, device=term_prob.device)
        term_mask = torch.where(should_terminate, one, zero)
        self.termination_mask = term_mask

        values = self.critic_linear(core_output)
        params, distribution = self.action_parameterization(core_output)

        # actions and their log probabilities are obtained in a single call,
        # this is faster for non-trivial action spaces
        sampled_actions, sampled_log_probs = sample_actions_log_probs(distribution)

        # original author's shape note: B x num_action_logits x O -> (B * O) x num_action_logits
        flat_params = params.reshape(-1, distribution.num_actions)

        # `action_logits` may not be the best key name, since continuous actions are supported too
        outputs = {
            'actions': sampled_actions,  # (B * O) x (num_actions/D), per the original author's note
            'action_logits': flat_params,
            'log_prob_actions': sampled_log_probs,  # (B * O) x 1, per the original author's note
            'values': values,
            'termination_prob': term_prob,
            'termination_mask': term_mask,
        }
        result = AttrDict(outputs)

        if with_action_distribution:
            result.action_distribution = distribution

        return result
Exemplo n.º 3
0
    def forward_tail(self, core_output, with_action_distribution=False):
        """
        Produce action and value outputs when the core output holds one slice per core.

        `core_output` is split along dim 1 into `len(self.cores)` chunks. The first chunk
        feeds the action parameterization (actor), the second feeds the critic head.

        :param core_output: concatenated output of the cores
        :param with_action_distribution: when truthy, the action distribution object is
            attached to the result as `action_distribution`
        :return: AttrDict with `actions`, `action_logits`, `log_prob_actions` and `values`
        """
        per_core = core_output.chunk(len(self.cores), dim=1)

        # actor: first core output
        params, distribution = self.action_parameterization(per_core[0])

        # actions and their log probabilities are obtained in a single call,
        # this is faster for non-trivial action spaces
        sampled_actions, sampled_log_probs = sample_actions_log_probs(distribution)

        # critic: second core output
        value_estimates = self.critic_linear(per_core[1])

        outputs = {
            'actions': sampled_actions,
            'action_logits': params,
            'log_prob_actions': sampled_log_probs,
            'values': value_estimates,
        }
        result = AttrDict(outputs)

        if with_action_distribution:
            result.action_distribution = distribution

        return result
Exemplo n.º 4
0
    def forward_tail(self, core_output, with_action_distribution=False):
        """
        Produce action and value outputs from a single shared core output.

        :param core_output: core output, passed to both the critic head and the
            action parameterization
        :param with_action_distribution: when truthy, the action distribution object is
            attached to the result as `action_distribution`
        :return: AttrDict with `actions`, `action_logits`, `log_prob_actions` and `values`
        """
        value_estimates = self.critic_linear(core_output)
        params, distribution = self.action_parameterization(core_output)

        # actions and their log probabilities are obtained in a single call,
        # this is faster for non-trivial action spaces
        sampled_actions, sampled_log_probs = sample_actions_log_probs(distribution)

        # `action_logits` may not be the best key name, since continuous actions are supported too
        outputs = {
            'actions': sampled_actions,
            'action_logits': params,
            'log_prob_actions': sampled_log_probs,
            'values': value_estimates,
        }
        result = AttrDict(outputs)

        if with_action_distribution:
            result.action_distribution = distribution

        return result