class CriticAtt(nn.Module):
    """Attention-based critic producing one scalar value per batch element.

    Each group's state vector is linearly embedded, the groups attend to each
    other through a single ``TransformerEncoderLayer``, and the result is
    reduced first over the embedding axis and then over the group axis.

    Args:
        state_dim: Size of the last dimension of the input state.
        num_group: Number of groups (sequence length). It is baked into
            ``linear_c2``, so inputs must contain exactly this many groups.
        emb_dim: Embedding width used as the encoder's ``d_model``.
        nhead: Number of attention heads of the encoder layer.
    """

    def __init__(self, state_dim: int, num_group: int, emb_dim: int = 128,
                 nhead: int = 8):
        super().__init__()
        self.emb = nn.Linear(state_dim, emb_dim)
        self.att_encoder = TransformerEncoderLayer(d_model=emb_dim,
                                                   nhead=nhead)
        # Collapses the embedding axis: emb_dim -> 1.
        self.linear_c1 = nn.Linear(emb_dim, 1)
        # Collapses the group axis: num_group -> 1.
        self.linear_c2 = nn.Linear(num_group, 1)

    def forward(self, state):
        """Compute the value estimate for a batch of states.

        Args:
            state: Tensor of shape [batch_size x num_group x state_dim].

        Returns:
            Tensor of shape [batch_size].
        """
        emb_state = self.emb(state)  # [batch_size x num_group x emb_dim]
        # The encoder layer is built without batch_first, so it consumes
        # sequence-first input.
        emb_state = emb_state.transpose(
            0, 1)  # [num_group x batch_size x emb_dim]

        att_output = self.att_encoder(emb_state)
        att_output = att_output.transpose(
            0, 1)  # [batch_size x num_group x emb_dim]
        # Re-use the encoder layer's own activation and dropout on its output
        # (dropout is expected to be a no-op in eval mode).
        att_output = self.att_encoder.dropout(
            self.att_encoder.activation(att_output))

        # output layer
        att_reduced = self.linear_c1(att_output).squeeze(
            -1)  # [batch_size x num_group]
        value = self.linear_c2(att_reduced).squeeze(-1)  # [batch_size]
        return value


class ActorAtt(nn.Module):
    """Attention-based actor producing one logit per group.

    Each group's state vector is linearly embedded, the groups attend to each
    other through a single ``TransformerEncoderLayer``, and every group's
    encoded vector is projected to a single scalar logit.

    Args:
        state_dim: Size of the last dimension of the input state.
        emb_dim: Embedding width used as the encoder's ``d_model``.
        nhead: Number of attention heads of the encoder layer.
    """

    def __init__(self, state_dim: int, emb_dim: int = 128, nhead: int = 8):
        super().__init__()
        self.emb = nn.Linear(state_dim, emb_dim)
        self.att_encoder = TransformerEncoderLayer(d_model=emb_dim,
                                                   nhead=nhead)
        # Projects each group's encoding to one logit: emb_dim -> 1.
        self.out_linear = nn.Linear(emb_dim, 1)

    def forward(self, state):
        """Compute per-group logits for a batch of states.

        Args:
            state: Tensor of shape [batch_size x num_group x state_dim].

        Returns:
            Tensor of shape [batch_size x num_group] with unnormalized
            logits (no softmax is applied here).
        """
        emb_state = self.emb(state)  # [batch_size x num_group x emb_dim]
        # The encoder layer is built without batch_first, so it consumes
        # sequence-first input.
        emb_state = emb_state.transpose(
            0, 1)  # [num_group x batch_size x emb_dim]

        att_output = self.att_encoder(emb_state)
        att_output = att_output.transpose(
            0, 1)  # [batch_size x num_group x emb_dim]
        # Re-use the encoder layer's own activation and dropout on its output
        # (dropout is expected to be a no-op in eval mode).
        att_output = self.att_encoder.dropout(
            self.att_encoder.activation(att_output))

        # output layer
        logits = self.out_linear(att_output).squeeze(
            -1)  # [batch_size x num_group]

        return logits