Example #1
from abc import ABC

import torch
import torch.nn as nn

# MLP, init, and fanin_init are project-level helpers assumed to be importable.
class QCritic(nn.Module, ABC):
    def __init__(self,
                 state_dim,
                 action_space,
                 hidden_dims,
                 activation='relu',
                 last_activation='Identity',
                 init_w=3e-3,
                 init_b=0.1,
                 use_multihead_output=False):
        super(QCritic, self).__init__()

        assert not use_multihead_output or action_space.__class__.__name__ == 'Discrete'

        if action_space.__class__.__name__ == 'Discrete':
            action_dim = action_space.n
        else:
            assert action_space.__class__.__name__ == 'Box'
            action_dim = action_space.shape[0]

        if use_multihead_output:
            # Discrete multihead: the critic takes only the state and outputs one Q-value per action.
            action_dim = action_space.n
            self.critic = MLP(state_dim,
                              action_dim,
                              hidden_dims,
                              activation=activation,
                              last_activation=last_activation)
            self.forward = self._get_q_value_discrete
        else:
            # Single-output critic: Q(s, a) from the concatenated state-action vector.
            self.critic = MLP(state_dim + action_dim,
                              1,
                              hidden_dims,
                              activation=activation,
                              last_activation=last_activation)
            self.forward = self._get_q_value_continuous

        def init_(m):
            init(m, fanin_init, lambda x: nn.init.constant_(x, init_b))

        def init_last_(m):
            init(m, lambda x: nn.init.uniform_(x, -init_w, init_w),
                 lambda x: nn.init.uniform_(x, -init_w, init_w))

        self.critic.init(init_, init_last_)

    def _get_q_value_continuous(self, state, action):
        return self.critic(torch.cat([state, action], dim=-1))

    def _get_q_value_discrete(self, state, action):
        # One Q-value per action; select the taken action's value (action as an index tensor of shape [..., 1]).
        return self.critic(state).gather(-1, action.long())
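
A minimal usage sketch for the continuous-action path (the state dimension, hidden sizes, and Box bounds are illustrative assumptions; MLP, init, and fanin_init must come from the project):

import gym
import torch

action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,))
critic = QCritic(state_dim=17, action_space=action_space, hidden_dims=[256, 256])

states = torch.randn(32, 17)
actions = torch.randn(32, 6)
q_values = critic(states, actions)  # forward is bound to _get_q_value_continuous, output shape (32, 1)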
Example #2
from abc import ABC
from operator import itemgetter
from typing import List, Union

import torch

# BaseDynamics, MLP, init, and truncated_norm_init are project-level helpers assumed to be importable.
class RDynamics(BaseDynamics, ABC):
    def __init__(self,
                 state_dim: int,
                 action_dim: int,
                 reward_dim: int,
                 hidden_dims: List[int],
                 output_state_dim=None,
                 **kwargs):
        super(RDynamics, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.output_state_dim = output_state_dim or state_dim

        assert kwargs.get('last_activation', 'identity') == 'identity'
        self.diff_dynamics = MLP(state_dim + action_dim,
                                 output_state_dim + reward_dim,
                                 hidden_dims,
                                 activation='swish',
                                 **kwargs)

        def init_(m):
            init(m, truncated_norm_init, lambda x: nn.init.constant_(x, 0))

        self.diff_dynamics.init(init_, init_)

    def forward(self, states, actions):
        # Predict state deltas (residuals) and rewards jointly from the concatenated state-action input.
        x = torch.cat([states, actions], dim=-1)
        x = self.diff_dynamics(x)
        diff_states = x[..., :self.output_state_dim]
        rewards = x[..., self.output_state_dim:]
        return {'diff_states': diff_states, 'rewards': rewards}

    def predict(self, states, actions, **kwargs):
        diff_states, rewards = itemgetter('diff_states', 'rewards')(
            self.forward(states, actions))
        return {'next_states': states + diff_states, 'rewards': rewards}

    def compute_l2_loss(self, l2_loss_coefs: Union[float, List[float]]):
        weight_norms = []
        for name, weight in self.diff_dynamics.named_parameters():
            if "weight" in name:
                weight_norms.append(weight.norm(2))
        weight_norms = torch.stack(weight_norms, dim=0)
        weight_decay = (
            torch.tensor(l2_loss_coefs, device=weight_norms.device) *
            weight_norms).sum()
        return weight_decay
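
An illustrative usage sketch (dimensions and L2 coefficients below are assumptions; the coefficient list must have one entry per weight matrix, here two hidden layers plus the output layer):

import torch

dynamics = RDynamics(state_dim=11, action_dim=3, reward_dim=1, hidden_dims=[200, 200])

states = torch.randn(64, 11)
actions = torch.randn(64, 3)

pred = dynamics.predict(states, actions)             # {'next_states': (64, 11), 'rewards': (64, 1)}
l2_loss = dynamics.compute_l2_loss([2.5e-5, 5e-5, 7.5e-5])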
Example #3
from abc import ABC
from typing import List, Optional

import torch.nn as nn
from gym.spaces import Box, Discrete, MultiBinary

# MLP, init, fanin_init, and TanhGaussainActorLayer are project-level helpers assumed to be importable.
class Actor(nn.Module, ABC):
    def __init__(self, state_dim: int, action_space, hidden_dims: List[int],
                 state_normalizer: Optional[nn.Module], use_limited_entropy=False, use_tanh_squash=False,
                 use_state_dependent_std=False, **kwargs):
        super(Actor, self).__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.hidden_dims = hidden_dims
        self.use_limited_entropy = use_limited_entropy
        self.use_tanh_squash = use_tanh_squash

        if isinstance(action_space, (Box, MultiBinary)):
            self.action_dim = action_space.shape[0]
        else:
            assert isinstance(action_space, Discrete)
            self.action_dim = action_space.n

        mlp_kwargs = kwargs.copy()
        mlp_kwargs['activation'] = kwargs.get('activation', 'relu')
        mlp_kwargs['last_activation'] = kwargs.get('activation', 'relu')

        self.actor_feature = MLP(state_dim, hidden_dims[-1], hidden_dims[:-1], **mlp_kwargs)

        self.state_normalizer = state_normalizer or nn.Identity()

        self.actor_layer = TanhGaussainActorLayer(hidden_dims[-1], self.action_dim,
                                                  use_state_dependent_std)

        def init_(m):
            init(m, fanin_init, lambda x: nn.init.constant_(x, 0))

        self.actor_feature.init(init_, init_)

    def act(self, state, deterministic=False, reparameterize=False):
        action_feature = self.actor_feature(state)
        action_dist, action_means, action_logstds = self.actor_layer(action_feature)

        log_probs = None
        pretanh_actions = None

        if deterministic:
            actions = action_means
        else:
            if reparameterize:
                result = action_dist.rsample()
            else:
                result = action_dist.sample()
            actions, pretanh_actions = result
            log_probs = action_dist.log_probs(actions, pretanh_actions)

        entropy = action_dist.entropy().mean()

        return {'actions': actions, 'log_probs': log_probs, 'entropy': entropy,
                'action_means': action_means, 'action_logstds': action_logstds, 'pretanh_actions': pretanh_actions}

    def evaluate_actions(self, states, actions, pretanh_actions=None):
        states = self.state_normalizer(states)

        action_feature = self.actor_feature(states)
        action_dist, *_ = self.actor_layer(action_feature)

        if pretanh_actions is not None:
            log_probs = action_dist.log_probs(actions, pretanh_actions)
        else:
            log_probs = action_dist.log_probs(actions)

        entropy = action_dist.entropy().mean()

        return {'log_probs': log_probs, 'entropy': entropy}
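
An illustrative usage sketch (environment dimensions are assumptions, and gym.spaces is assumed as the action-space source; passing state_normalizer=None falls back to nn.Identity inside the class):

import torch
from gym.spaces import Box

action_space = Box(low=-1.0, high=1.0, shape=(6,))
actor = Actor(state_dim=17, action_space=action_space, hidden_dims=[256, 256],
              state_normalizer=None)

states = torch.randn(32, 17)
out = actor.act(states, reparameterize=True)         # reparameterized samples for pathwise gradients
evaluated = actor.evaluate_actions(states, out['actions'],
                                   pretanh_actions=out['pretanh_actions'])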