Example #1
class SimpleRNN(PeGradNet):
    def __init__(self, input_size, hidden_size, num_classes, train_alg='batch'):
        super(SimpleRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = num_classes
        self.train_alg = train_alg

        self.rnn = RNNCell(input_size, hidden_size)
        self.fc = Linear(self.hidden_size, self.output_size)


    def forward(self, x):
        x = x.squeeze(1).permute(1, 0, 2)  # seq_len x batch_size x input_size

        self.rnn.reset_pgrad()

        hx = torch.zeros(x.shape[1], self.hidden_size, device=x.device)
        
        for t in range(x.shape[0]):
            hx = self.rnn(x[t], hx)

        logits = self.fc(hx)

        return logits

    def per_example_gradient(self, loss):
        grads = []

        # Copy so the RNN cell's cached pre-activation list is not mutated.
        pre_acts = list(self.rnn.pre_activation)
        pre_acts.append(self.fc.pre_activation)
        
        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        grads.extend(self.rnn.per_example_gradient(Z_grad[:-1]))
        grads.extend(self.fc.per_example_gradient(Z_grad[-1]))

        return grads

    def pe_grad_norm(self, loss, batch_size, device):
        grad_norm = torch.zeros(batch_size, device=device, requires_grad=False)
        
        # Copy so the RNN cell's cached pre-activation list is not mutated.
        pre_acts = list(self.rnn.pre_activation)
        pre_acts.append(self.fc.pre_activation)

        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)
        grad_norm.add_(self.rnn.pe_grad_sqnorm(Z_grad[:-1]))
        grad_norm.add_(self.fc.pe_grad_sqnorm(Z_grad[-1]))

        grad_norm.sqrt_()

        return grad_norm
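
Below is a minimal, hypothetical usage sketch for SimpleRNN. The input shape (MNIST-style rows as timesteps), the hyperparameters, and the 'sum' loss reduction are assumptions, not part of the original code; it only illustrates how pe_grad_norm is expected to be called.

import torch
import torch.nn.functional as F

# Hypothetical usage: per-example gradient norms for one batch.
model = SimpleRNN(input_size=28, hidden_size=128, num_classes=10)
x = torch.randn(32, 1, 28, 28)            # batch_size x 1 x seq_len x input_size
y = torch.randint(0, 10, (32,))
logits = model(x)                          # batch_size x num_classes
# 'sum' keeps each example's contribution to the loss unscaled (assumption).
loss = F.cross_entropy(logits, y, reduction='sum')
norms = model.pe_grad_norm(loss, batch_size=32, device=x.device)  # shape: (32,)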
Example #2
class TransformerModel(nn.Module):
    def __init__(self, n_token, n_classes, d_model=512, n_layers=2,
                 n_head=8, n_hidden=2048, dropout=0.1, max_seq_len=512,
                 embeddings=None, train_alg='batch'):
        super(TransformerModel, self).__init__()

        self.train_alg = train_alg
        self.d_model = d_model
        self.n_head = n_head

        if embeddings is None:            
            self.token_embedding = nn.Embedding(n_token, d_model)
        else:
            self.token_embedding = nn.Embedding.from_pretrained(embeddings)
            self.token_embedding.weight.requires_grad = False

        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_len)        
        encoder_layers = TransformerEncoderLayer(d_model, n_head, n_hidden, dropout)
        # encoder_norm = nn.LayerNorm(d_model)
        encoder_norm = None
        self.encoder = TransformerEncoder(encoder_layers, n_layers, encoder_norm)
        self.fc = Linear(d_model, n_classes)

    def init_weights(self):
        initrange = 0.1
        self.token_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        # positions = torch.arange(len(x), device=x.device).unsqueeze(-1)

        x = x.transpose(0, 1)
        # [sentence length, batch_size]
        x = self.token_embedding(x)
        # [sentence length, batch_size, embedding dim]
        x = self.pos_encoder(x)
        # x = x + self.pos_encoder(positions).expand_as(x)
        
        # [sentence length, batch_size, embedding dim]
        output = self.encoder(x)
        # [sentence length, batch_size, embedding dim]
        avg_out = output.transpose(0, 1).mean(dim=1)
        # [batch_size, embedding dim]
        preact = self.fc(avg_out)
        
        # [batch_size, num_classes]
        # return F.log_softmax(output, dim=-1)
        return preact

    def per_example_gradient(self, loss):
        grads = []
        pre_acts = []

        pre_acts.extend(self.encoder.collect_preactivations())
        pre_acts.append(self.fc.pre_activation)

        # Gradients of the loss w.r.t. every cached pre-activation.
        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        # TransformerEncoder layers
        grads.extend(self.encoder.per_example_gradient(Z_grad[:-1]))

        # fully connected output layer
        grads.extend(self.fc.per_example_gradient(Z_grad[-1]))

        return grads

    def pe_grad_norm(self, loss, batch_size, device):
        grad_norm = torch.zeros(batch_size, device=device)

        pre_acts = []
        pre_acts.extend(self.encoder.collect_preactivations())
        pre_acts.append(self.fc.pre_activation)

        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        grad_norm.add_(self.encoder.pe_grad_sqnorm(Z_grad[:-1]))
        grad_norm.add_(self.fc.pe_grad_sqnorm(Z_grad[-1]))        
        grad_norm.sqrt_()

        return grad_norm
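
A similar hypothetical sketch for TransformerModel; the vocabulary size, sequence length, model sizes, and the 'sum' reduction are illustrative assumptions rather than values from the original code.

import torch
import torch.nn.functional as F

# Hypothetical usage: sequence classification with per-example gradient norms.
model = TransformerModel(n_token=10000, n_classes=2, d_model=128,
                         n_layers=2, n_head=4, n_hidden=256)
tokens = torch.randint(0, 10000, (16, 64))   # batch_size x seq_len
labels = torch.randint(0, 2, (16,))
logits = model(tokens)                        # batch_size x n_classes
loss = F.cross_entropy(logits, labels, reduction='sum')
norms = model.pe_grad_norm(loss, batch_size=16, device=tokens.device)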
Example #3
class MultiheadAttention(nn.Module):
    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.in_proj = Linear(embed_dim, 3 * embed_dim)
        # self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))

        # if bias:
        #     self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        # else:
        #     self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj.weight)

        if self.in_proj.bias is not None:
            constant_(self.in_proj.bias, 0.)
            constant_(self.out_proj.bias, 0.)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=True,
                attn_mask=None):
        attn_out, _ = multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj,
            self.dropout,
            self.out_proj,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask)

        return attn_out

    def per_example_gradient(self, deriv_pre_activ_in, deriv_pre_activ_out):
        pe_grad_weight_in, pe_grad_bias_in = \
            self.in_proj.per_example_gradient(deriv_pre_activ_in)
        pe_grad_weight_out, pe_grad_bias_out = \
            self.out_proj.per_example_gradient(deriv_pre_activ_out)

        return (pe_grad_weight_in, pe_grad_bias_in, pe_grad_weight_out,
                pe_grad_bias_out)

    def pe_grad_sqnorm(self, deriv_pre_activ):
        grads = self.per_example_gradient(*deriv_pre_activ)
        batch_size = grads[0].size(0)

        grad_sq_norm = torch.zeros(batch_size, device=grads[0].device)
        for grad in grads:
            grad_sq_norm.add_(grad.pow(2).view(batch_size, -1).sum(1))

        return grad_sq_norm

    def collect_preactivations(self):
        return (self.in_proj.pre_activation, self.out_proj.pre_activation)
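
A hypothetical sketch of how the cached pre-activations feed pe_grad_sqnorm, mirroring the pattern in the first two examples. The toy loss, the tensor sizes, and the assumption that the custom Linear layers accept the (seq_len, batch, embed_dim) layout are mine, not the original author's.

import torch

# Hypothetical usage: self-attention followed by per-example squared grad norms.
mha = MultiheadAttention(embed_dim=64, num_heads=4)
x = torch.randn(10, 8, 64)                 # seq_len x batch_size x embed_dim
out = mha(x, x, x)                         # query = key = value (self-attention)
loss = out.pow(2).sum()                    # toy scalar loss
# Gradients of the loss w.r.t. the cached in_proj / out_proj pre-activations.
Z_grad = torch.autograd.grad(loss, mha.collect_preactivations(), retain_graph=True)
sq_norms = mha.pe_grad_sqnorm(Z_grad)      # per-example squared gradient norms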