from abc import ABC
from operator import itemgetter
from typing import List, Optional, Union

import torch
import torch.nn as nn
from gym.spaces import Box, Discrete, MultiBinary

# MLP, init, fanin_init, truncated_norm_init, BaseDynamics, and
# TanhGaussainActorLayer are project-local helpers assumed importable
# from this repository.


class QCritic(nn.Module, ABC):
    def __init__(self, state_dim, action_space, hidden_dims, activation='relu',
                 last_activation='Identity', init_w=3e-3, init_b=0.1,
                 use_multihead_output=False):
        super(QCritic, self).__init__()
        # Multi-head output (one Q-value per action) only makes sense for
        # discrete action spaces.
        assert not use_multihead_output or action_space.__class__.__name__ == 'Discrete'
        if action_space.__class__.__name__ == 'Discrete':
            action_dim = action_space.n
        else:
            assert action_space.__class__.__name__ == 'Box'
            action_dim = action_space.shape[0]

        if use_multihead_output:
            # Discrete case: map the state to one Q-value per action.
            self.critic = MLP(state_dim, action_dim, hidden_dims,
                              activation=activation, last_activation=last_activation)
            self.forward = self._get_q_value_discrete
        else:
            # Continuous case: map the (state, action) pair to a scalar Q-value.
            self.critic = MLP(state_dim + action_dim, 1, hidden_dims,
                              activation=activation, last_activation=last_activation)
            self.forward = self._get_q_value_continuous

        def init_(m):
            init(m, fanin_init, lambda x: nn.init.constant_(x, init_b))

        def init_last_(m):
            init(m, lambda x: nn.init.uniform_(x, -init_w, init_w),
                 lambda x: nn.init.uniform_(x, -init_w, init_w))

        self.critic.init(init_, init_last_)

    def _get_q_value_continuous(self, state, action):
        return self.critic(torch.cat([state, action], dim=-1))

    def _get_q_value_discrete(self, state, action):
        # Select the Q-value of the taken action (shape (..., 1), integer
        # indices) from the per-action heads. The original indexed a
        # non-existent `self.critic_feature` attribute with a tensor; gather
        # over the last dimension is the tensor-safe equivalent.
        return self.critic(state).gather(-1, action.long())
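# Usage sketch for QCritic (illustrative only: the dimensions and Box bounds
# below are assumptions, and MLP/init/fanin_init must come from this repo):
#
#   critic = QCritic(state_dim=8,
#                    action_space=Box(low=-1.0, high=1.0, shape=(2,)),
#                    hidden_dims=[256, 256])
#   states = torch.randn(32, 8)
#   actions = torch.randn(32, 2)
#   q_values = critic(states, actions)  # dispatches to _get_q_value_continuous
#   assert q_values.shape == (32, 1)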
class RDynamics(BaseDynamics, ABC):
    def __init__(self, state_dim: int, action_dim: int, reward_dim: int,
                 hidden_dims: List[int], output_state_dim=None, **kwargs):
        super(RDynamics, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        self.output_state_dim = output_state_dim or state_dim
        # kwargs is a dict, so use .get(); the original `getattr` on a dict
        # always returned the default, silently skipping this check.
        assert kwargs.get('last_activation', 'identity') == 'identity'
        # The network predicts the state *difference* plus the reward.
        # Use self.output_state_dim here (the raw argument may be None).
        self.diff_dynamics = MLP(state_dim + action_dim,
                                 self.output_state_dim + reward_dim,
                                 hidden_dims, activation='swish', **kwargs)

        def init_(m):
            init(m, truncated_norm_init, lambda x: nn.init.constant_(x, 0))

        self.diff_dynamics.init(init_, init_)

    def forward(self, states, actions):
        x = torch.cat([states, actions], dim=-1)
        x = self.diff_dynamics(x)
        diff_states = x[..., :self.output_state_dim]
        rewards = x[..., self.output_state_dim:]
        return {'diff_states': diff_states, 'rewards': rewards}

    def predict(self, states, actions, **kwargs):
        diff_states, rewards = itemgetter('diff_states', 'rewards')(
            self.forward(states, actions))
        # Residual prediction: next state = current state + predicted difference.
        return {'next_states': states + diff_states, 'rewards': rewards}

    def compute_l2_loss(self, l2_loss_coefs: Union[float, List[float]]):
        # Per-layer L2 penalty: one coefficient per weight matrix, so layers
        # can be regularized with different strengths.
        weight_norms = []
        for name, weight in self.diff_dynamics.named_parameters():
            if 'weight' in name:
                weight_norms.append(weight.norm(2))
        weight_norms = torch.stack(weight_norms, dim=0)
        weight_decay = (torch.tensor(l2_loss_coefs, device=weight_norms.device)
                        * weight_norms).sum()
        return weight_decay
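# Usage sketch for RDynamics (illustrative dimensions; assumes the project's
# MLP builds len(hidden_dims) + 1 linear layers, hence three L2 coefficients
# for hidden_dims=[200, 200]):
#
#   model = RDynamics(state_dim=8, action_dim=2, reward_dim=1,
#                     hidden_dims=[200, 200])
#   states, actions = torch.randn(32, 8), torch.randn(32, 2)
#   pred = model.predict(states, actions)
#   next_states, rewards = pred['next_states'], pred['rewards']
#   l2_penalty = model.compute_l2_loss([2.5e-5, 5e-5, 7.5e-5])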
class Actor(nn.Module, ABC):
    def __init__(self, state_dim: int, action_space, hidden_dims: List[int],
                 state_normalizer: Optional[nn.Module], use_limited_entropy=False,
                 use_tanh_squash=False, use_state_dependent_std=False, **kwargs):
        super(Actor, self).__init__()
        self.state_dim = state_dim
        self.action_space = action_space
        self.hidden_dims = hidden_dims
        self.use_limited_entropy = use_limited_entropy
        self.use_tanh_squash = use_tanh_squash

        if isinstance(action_space, (Box, MultiBinary)):
            self.action_dim = action_space.shape[0]
        else:
            assert isinstance(action_space, Discrete)
            self.action_dim = action_space.n

        # The feature trunk keeps its activation on the last layer as well;
        # the distribution head below produces the final outputs.
        mlp_kwargs = kwargs.copy()
        mlp_kwargs['activation'] = kwargs.get('activation', 'relu')
        mlp_kwargs['last_activation'] = kwargs.get('activation', 'relu')
        self.actor_feature = MLP(state_dim, hidden_dims[-1], hidden_dims[:-1],
                                 **mlp_kwargs)
        self.state_normalizer = state_normalizer or nn.Identity()
        self.actor_layer = TanhGaussainActorLayer(hidden_dims[-1], self.action_dim,
                                                  use_state_dependent_std)

        def init_(m):
            init(m, fanin_init, lambda x: nn.init.constant_(x, 0))

        self.actor_feature.init(init_, init_)

    def act(self, state, deterministic=False, reparameterize=False):
        # Normalize here as well, so act() matches evaluate_actions().
        state = self.state_normalizer(state)
        action_feature = self.actor_feature(state)
        action_dist, action_means, action_logstds = self.actor_layer(action_feature)

        log_probs = None
        pretanh_actions = None
        if deterministic:
            actions = action_means
        else:
            # rsample() keeps sampling differentiable (reparameterization
            # trick); sample() detaches the sample from the graph.
            if reparameterize:
                result = action_dist.rsample()
            else:
                result = action_dist.sample()
            actions, pretanh_actions = result
            log_probs = action_dist.log_probs(actions, pretanh_actions)

        entropy = action_dist.entropy().mean()
        return {'actions': actions, 'log_probs': log_probs, 'entropy': entropy,
                'action_means': action_means, 'action_logstds': action_logstds,
                'pretanh_actions': pretanh_actions}

    def evaluate_actions(self, states, actions, pretanh_actions=None):
        states = self.state_normalizer(states)
        action_feature = self.actor_feature(states)
        action_dist, *_ = self.actor_layer(action_feature)
        # Truthiness of a tensor is ambiguous; test against None explicitly.
        if pretanh_actions is not None:
            # Reusing the pre-tanh actions avoids re-inverting tanh, which is
            # numerically unstable near the action bounds.
            log_probs = action_dist.log_probs(actions, pretanh_actions)
        else:
            log_probs = action_dist.log_probs(actions)
        entropy = action_dist.entropy().mean()
        return {'log_probs': log_probs, 'entropy': entropy}
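# Usage sketch for Actor (illustrative; TanhGaussainActorLayer is this repo's
# squashed-Gaussian head, and the Box bounds and sizes are assumptions):
#
#   actor = Actor(state_dim=8,
#                 action_space=Box(low=-1.0, high=1.0, shape=(2,)),
#                 hidden_dims=[256, 256],
#                 state_normalizer=None)
#   states = torch.randn(32, 8)
#   out = actor.act(states, reparameterize=True)  # rsample: gradients flow
#   eval_out = actor.evaluate_actions(states, out['actions'],
#                                     out['pretanh_actions'])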