def test_log_prob_gradient(self):
        """
        Check the gradients of the log-probability w.r.t. mean and std.

        The squashed value and its pre-tanh value are plain constants, so
        (assuming the tanh correction in TanhNormal has no dependence on the
        distribution parameters) only the Gaussian term contributes:

            d/d mu    log N(z; mu, sigma) = (z - mu) / sigma^2
            d/d sigma log N(z; mu, sigma) = (z - mu)^2 / sigma^3 - 1 / sigma

        With mu = 0, sigma = 0.25 and z = 1 these evaluate to 16 and
        4^3 - 4 = 60, which are the values asserted below.
        """
        # 1 / sigma for sigma = 0.25; an int so the expected arrays stay
        # integer-typed.
        inv_sigma = 4

        mu = ptu.from_numpy(np.array([0]), requires_grad=True)
        sigma = ptu.from_numpy(np.array([0.25]), requires_grad=True)
        distribution = TanhNormal(mu, sigma)

        pre_tanh = ptu.from_numpy(np.array([1]))
        squashed = torch.tanh(pre_tanh)
        log_prob = distribution.log_prob(squashed, pre_tanh_value=pre_tanh)

        # Back-propagate with an upstream gradient of one.
        log_prob.backward(ptu.from_numpy(np.array([1])))

        expected_mu_grad = np.array([inv_sigma ** 2])
        expected_sigma_grad = np.array([inv_sigma ** 3 - inv_sigma])
        self.assertNpArraysEqual(ptu.get_numpy(mu.grad), expected_mu_grad)
        self.assertNpArraysEqual(
            ptu.get_numpy(sigma.grad),
            expected_sigma_grad,
        )
Example #2
0
    def log_prob(self, observation, state, actions=None):
        """
        Log-probability of ``actions`` under the policy's TanhNormal.

        :param observation: Input fed through ``self.convs``.
        :param state: Optional extra features; when not None they are
        concatenated to the flattened conv output along dim 1.
        :param actions: Squashed actions to score. They are passed straight
        to ``self.atanh``. NOTE(review): the ``None`` default is presumably
        not a usable value here -- confirm against ``self.atanh``.
        :return: ``TanhNormal.log_prob`` summed over the last dimension.
        """
        pre_tanh_actions = self.atanh(actions)

        # Convolutional trunk. The observation is used as given; no input
        # rescaling (e.g. division by 255) is applied.
        features = observation
        for conv in self.convs:
            features = self.conv_activation(conv(features))

        features = batch_flatten(features)
        if state is not None:
            features = torch.cat((features, state), 1)

        final_index = len(self.fcs) - 1
        for index, fc in enumerate(self.fcs):
            # Keep the input of the current layer: after the loop this is
            # the input of the last fc layer, which also feeds the log-std
            # head below.
            last_layer_input = features
            features = fc(features)
            if index == final_index:
                features = self.output_activation(features)
            else:
                features = self.fcs_activation(features)

        mean = torch.clamp(features, LOG_MEAN_MIN, LOG_MEAN_MAX)
        log_std = torch.clamp(
            self.last_fc_log_std(last_layer_input),
            LOG_SIG_MIN,
            LOG_SIG_MAX,
        )
        std = torch.exp(log_std)

        distribution = TanhNormal(mean, std)
        per_dim_log_prob = distribution.log_prob(
            value=actions,
            pre_tanh_value=pre_tanh_actions,
        )
        return per_dim_log_prob.sum(-1)
    def test_log_prob_value(self):
        """
        Compare TanhNormal(0, 1).log_prob at tanh(1) with a NumPy reference.

        The reference is the standard-normal log-density at z = 1,
        log(1 / sqrt(2 pi)) - 1/2, minus the tanh change-of-variables term
        log(1 - tanh(z)^2).
        """
        pre_tanh_np = np.array([1])
        squashed_np = np.tanh(pre_tanh_np)

        distribution = TanhNormal(0, 1)
        actual = ptu.get_numpy(
            distribution.log_prob(ptu.from_numpy(squashed_np)))

        # Standard normal log-density evaluated at z = 1.
        gaussian_term = np.log(np.array([1 / np.sqrt(2 * np.pi)])) - 0.5
        tanh_correction = np.log(1 - squashed_np**2)
        expected = gaussian_term - tanh_correction

        self.assertNpArraysEqual(expected, actual)
Example #4
0
    def log_prob_aviral(self, obs, actions):
        """
        Log-probability of already-squashed ``actions`` given ``obs``.

        :param obs: Input passed to ``self.obs_processor``.
        :param actions: Squashed actions; their pre-tanh values are
        recovered with a clamped inverse tanh.
        :return: ``TanhNormal.log_prob`` summed over the last dimension.
        """
        # Inverse tanh, 0.5 * log((1 + a) / (1 - a)). Both factors are
        # floored at 1e-6 so the log stays finite when an action reaches
        # +/-1.
        one_plus = (1 + actions).clamp(min=1e-6)
        one_minus = (1 - actions).clamp(min=1e-6)
        pre_tanh_actions = 0.5 * torch.log(one_plus / one_minus)

        net_out = self.mean_and_log_std_net(self.obs_processor(obs))
        # Split along dim 1 into pieces of size ``self.action_dim``; the
        # unpacking requires exactly two pieces (mean, then log-std).
        mean, log_std = torch.split(net_out, self.action_dim, dim=1)
        std = torch.exp(torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX))

        distribution = TanhNormal(mean, std)
        per_dim_log_prob = distribution.log_prob(
            value=actions,
            pre_tanh_value=pre_tanh_actions,
        )
        return per_dim_log_prob.sum(-1)
Example #5
0
    def log_prob_aviral(self, obs, actions):
        """
        Log-probability of already-squashed ``actions`` given ``obs``.

        :param obs: Input to the fully connected layers in ``self.fcs``.
        :param actions: Squashed actions; their pre-tanh values are
        recovered with a clamped inverse tanh.
        :return: ``TanhNormal.log_prob`` summed over the last dimension.
        """
        # Inverse tanh, 0.5 * log((1 + a) / (1 - a)). Both factors are
        # floored at 1e-6 so the log stays finite when an action reaches
        # +/-1.
        one_plus = (1 + actions).clamp(min=1e-6)
        one_minus = (1 - actions).clamp(min=1e-6)
        pre_tanh_actions = 0.5 * torch.log(one_plus / one_minus)

        hidden = obs
        for fc in self.fcs:
            hidden = self.hidden_activation(fc(hidden))
        mean = self.last_fc(hidden)

        if self.std is not None:
            # Fixed standard deviation configured on the module.
            std = self.std
            log_std = self.log_std  # read as in the learned branch; unused
        else:
            # Learned, state-dependent standard deviation.
            log_std = torch.clamp(
                self.last_fc_log_std(hidden), LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)

        distribution = TanhNormal(mean, std)
        per_dim_log_prob = distribution.log_prob(
            value=actions,
            pre_tanh_value=pre_tanh_actions,
        )
        return per_dim_log_prob.sum(-1)
Example #6
0
    def forward(
        self,
        observation,
        state,
        actions=None,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
    ):
        """
        Compute an action (and optionally its log-probability).

        :param observation: Input fed through ``self.convs``.
        :param state: Optional extra features; when not None they are
        concatenated to the flattened conv output along dim 1.
        :param actions: Accepted for interface compatibility; not used by
        this method.
        :param reparameterize: If exactly ``True``, draw with ``rsample``;
        any other value draws with ``sample``.
        :param deterministic: If True, do not sample; the action is
        ``tanh(mean)``.
        :param return_log_prob: If True (and not deterministic), also return
        the log-probability of the drawn action, summed over dim 1 with
        ``keepdim=True``, and its pre-tanh value.
        :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``. ``mean`` is the clamped
        mean. ``entropy`` and ``mean_action_log_prob`` are always None here;
        ``log_prob`` and ``pre_tanh_value`` are None unless a sample was
        drawn with ``return_log_prob=True``.
        """
        # Convolutional trunk. The observation is used as given; no input
        # rescaling (e.g. division by 255) is applied.
        tensor = observation
        for layer in self.convs:
            tensor = self.conv_activation(layer(tensor))

        tensor = batch_flatten(tensor)
        if state is not None:
            tensor = torch.cat((tensor, state), 1)

        last_index = len(self.fcs) - 1
        for index, layer in enumerate(self.fcs):
            # ``h`` ends up as the input of the last fc layer, which also
            # feeds the log-std head below.
            h = tensor
            tensor = layer(tensor)
            if index == last_index:
                tensor = self.output_activation(tensor)
            else:
                tensor = self.fcs_activation(tensor)

        log_std = torch.clamp(
            self.last_fc_log_std(h), LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)
        mean = torch.clamp(tensor, LOG_MEAN_MIN, LOG_MEAN_MAX)

        # Built once and unconditionally (the original rebuilt an identical
        # distribution inside the sampling branch).
        tanh_normal = TanhNormal(mean, std)

        log_prob = None
        entropy = None
        mean_action_log_prob = None
        pre_tanh_value = None

        if deterministic:
            action = torch.tanh(mean)
        elif return_log_prob:
            if reparameterize is True:
                action, pre_tanh_value = tanh_normal.rsample(
                    return_pretanh_value=True)
            else:
                action, pre_tanh_value = tanh_normal.sample(
                    return_pretanh_value=True)
            log_prob = tanh_normal.log_prob(action,
                                            pre_tanh_value=pre_tanh_value)
            log_prob = log_prob.sum(dim=1, keepdim=True)
        elif reparameterize is True:
            action = tanh_normal.rsample()
        else:
            action = tanh_normal.sample()

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Example #7
0
    def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """
        Compute an action for an observation that carries a tau component.

        :param obs: Observation; ``split_tau`` separates it into the
        observation proper and ``taus``.
        :param deterministic: If True, do not sample; the action is
        ``tanh(mean)``.
        :param return_log_prob: If True (and not deterministic), also return
        the log-probability of the sampled action and its pre-tanh value.
        :param return_entropy: If True, return the closed-form entropy of
        the Gaussian with the predicted log-std, summed over dim 1.
        :param return_log_prob_of_mean: If True, return the log-probability
        of the action ``tanh(mean)``, summed over dim 1.
        :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``; entries that were not
        requested are None.
        """
        obs, taus = split_tau(obs)
        n_rows = obs.size()[0]
        # Adding ``taus.data`` to a zero tensor yields one row of
        # ``self.tau_vector_len`` entries per observation row.
        tau_vector = Variable(
            torch.zeros((n_rows, self.tau_vector_len)) + taus.data)

        # Observation and tau are embedded separately, then concatenated.
        obs_features = self.hidden_activation(self.first_input(obs))
        tau_features = self.hidden_activation(self.second_input(tau_vector))
        hidden = torch.cat((obs_features, tau_features), dim=1)
        for fc in self.fcs:
            hidden = self.hidden_activation(fc(hidden))
        mean = self.last_fc(hidden)

        if self.std is not None:
            # Fixed standard deviation configured on the module.
            std = self.std
            log_std = self.log_std
        else:
            # Learned, state-dependent standard deviation.
            log_std = torch.clamp(
                self.last_fc_log_std(hidden), LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)

        log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            policy_dist = TanhNormal(mean, std)
            if return_log_prob:
                action, pre_tanh_value = policy_dist.sample(
                    return_pretanh_value=True)
                log_prob = policy_dist.log_prob(
                    action,
                    pre_tanh_value=pre_tanh_value,
                ).sum(dim=1, keepdim=True)
            else:
                action = policy_dist.sample()

        entropy = None
        if return_entropy:
            # Entropy of the Gaussian before squashing.
            # NOTE(review): differential entropy is not invariant under
            # tanh, so this is not the entropy of the squashed
            # distribution -- confirm callers expect the Gaussian value.
            gaussian_entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
            entropy = gaussian_entropy.sum(dim=1, keepdim=True)

        mean_action_log_prob = None
        if return_log_prob_of_mean:
            mean_dist = TanhNormal(mean, std)
            mean_action_log_prob = mean_dist.log_prob(
                torch.tanh(mean),
                pre_tanh_value=mean,
            ).sum(dim=1, keepdim=True)

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Example #8
0
    def forward(
        self,
        obs,
        reparameterize=True,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
    ):
        """
        Compute an action (and optional statistics) for an observation.

        :param obs: Observation
        :param reparameterize: If exactly ``True``, draw with ``rsample``;
        any other value draws with ``sample``.
        :param deterministic: If True, do not sample; the action is
        ``tanh(mean)``.
        :param return_log_prob: If True (and not deterministic), also return
        the log-probability of the drawn action and its pre-tanh value.
        :param return_entropy: Not supported: passing True always ends in
        ``NotImplementedError``.
        :param return_log_prob_of_mean: If True, return the log-probability
        of the action ``tanh(mean)``, summed over dim 1.
        :return: Tuple ``(action, mean, log_std, log_prob, entropy, std,
        mean_action_log_prob, pre_tanh_value)``; entries that were not
        requested are None.
        :raises NotImplementedError: if ``return_entropy`` is True.
        """
        hidden = obs
        for fc in self.fcs:
            hidden = self.hidden_activation(fc(hidden))
        mean = self.last_fc(hidden)

        if self.std is not None:
            # Fixed standard deviation configured on the module.
            std = self.std
            log_std = self.log_std
        else:
            # Learned, state-dependent standard deviation.
            log_std = torch.clamp(
                self.last_fc_log_std(hidden), LOG_SIG_MIN, LOG_SIG_MAX)
            std = torch.exp(log_std)

        log_prob = None
        pre_tanh_value = None
        if deterministic:
            action = torch.tanh(mean)
        else:
            policy_dist = TanhNormal(mean, std)
            use_rsample = reparameterize is True
            if return_log_prob:
                if use_rsample:
                    action, pre_tanh_value = policy_dist.rsample(
                        return_pretanh_value=True)
                else:
                    action, pre_tanh_value = policy_dist.sample(
                        return_pretanh_value=True)
                log_prob = policy_dist.log_prob(
                    action,
                    pre_tanh_value=pre_tanh_value,
                ).sum(dim=1, keepdim=True)
            elif use_rsample:
                action = policy_dist.rsample()
            else:
                action = policy_dist.sample()

        entropy = None
        if return_entropy:
            # The Gaussian entropy is evaluated but never returned: the
            # entropy of the tanh-squashed distribution is not implemented,
            # so this path always raises.
            gaussian_entropy = log_std + 0.5 + np.log(2 * np.pi) / 2
            entropy = gaussian_entropy.sum(dim=1, keepdim=True)
            raise NotImplementedError()

        mean_action_log_prob = None
        if return_log_prob_of_mean:
            mean_dist = TanhNormal(mean, std)
            mean_action_log_prob = mean_dist.log_prob(
                torch.tanh(mean),
                pre_tanh_value=mean,
            ).sum(dim=1, keepdim=True)

        return (
            action,
            mean,
            log_std,
            log_prob,
            entropy,
            std,
            mean_action_log_prob,
            pre_tanh_value,
        )
Example #9
0
 def logprob(self, action, mean, std):
     """
     Log-probability of ``action`` under ``TanhNormal(mean, std)``.

     No pre-tanh value is supplied, so ``TanhNormal.log_prob`` works from
     the squashed action alone.

     :return: The log-probability summed over dim 1 with ``keepdim=True``.
     """
     per_dim_log_prob = TanhNormal(mean, std).log_prob(action)
     return per_dim_log_prob.sum(dim=1, keepdim=True)
 def test_log_prob_type(self):
     """
     ``TanhNormal.log_prob`` must return a ``torch.autograd.Variable``.
     """
     distribution = TanhNormal(0, 1)
     zero_input = ptu.from_numpy(np.array([0]))
     result = distribution.log_prob(zero_input)
     self.assertIsInstance(result, torch.autograd.Variable)