Example #1
class SimpleRNN(PeGradNet):
    def __init__(self, input_size, hidden_size, num_classes, train_alg='batch'):
        super(SimpleRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = num_classes
        self.train_alg = train_alg

        self.rnn = RNNCell(input_size, hidden_size)
        self.fc = Linear(self.hidden_size, self.output_size)


    def forward(self, x):
        x = x.squeeze(1).permute(1, 0, 2)  # seq_len x batch_size x input_size

        self.rnn.reset_pgrad()

        hx = torch.zeros(x.shape[1], self.hidden_size, device=x.device)
        
        for t in range(x.shape[0]):
            hx = self.rnn(x[t], hx)

        logits = self.fc(hx)

        return logits

    def per_example_gradient(self, loss):
        grads = []

        # Copy the list so that appending the fc pre-activation does not
        # mutate the RNN cell's own pre_activation list.
        pre_acts = list(self.rnn.pre_activation)
        pre_acts.append(self.fc.pre_activation)
        
        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        grads.extend(self.rnn.per_example_gradient(Z_grad[:-1]))
        grads.extend(self.fc.per_example_gradient(Z_grad[-1]))

        return grads

    def pe_grad_norm(self, loss, batch_size, device):
        grad_norm = torch.zeros(batch_size, device=device, requires_grad=False)
        
        # Copy to avoid mutating the RNN cell's pre_activation list.
        pre_acts = list(self.rnn.pre_activation)
        pre_acts.append(self.fc.pre_activation)

        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)
        grad_norm.add_(self.rnn.pe_grad_sqnorm(Z_grad[:-1]))
        grad_norm.add_(self.fc.pe_grad_sqnorm(Z_grad[-1]))

        grad_norm.sqrt_()

        return grad_norm
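
A minimal usage sketch, not part of the original listing: it assumes the project-internal PeGradNet, RNNCell, and Linear modules are importable alongside the class above, and that the loss is summed over the batch so the pre-activation gradients decompose into per-example terms.

import torch
import torch.nn.functional as F

# Hypothetical shapes: 32 sequences of 28 steps with 28 features each;
# the extra dim of size 1 is removed by the squeeze(1) in forward().
batch_size, seq_len, input_size, hidden_size, num_classes = 32, 28, 28, 64, 10

model = SimpleRNN(input_size, hidden_size, num_classes)
x = torch.randn(batch_size, 1, seq_len, input_size)
y = torch.randint(0, num_classes, (batch_size,))

logits = model(x)
# reduction='sum' keeps the per-example loss terms independent, so the
# gradient w.r.t. each pre-activation splits into per-example pieces.
loss = F.cross_entropy(logits, y, reduction='sum')

# Per-example gradient L2 norms, e.g. for DP-SGD style clipping.
norms = model.pe_grad_norm(loss, batch_size, x.device)  # shape: (32,)
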
Example #2
class SimpleLSTM(PeGradNet):
    def __init__(self, input_size, hidden_size, output_size, train_alg='batch'):
        super(SimpleLSTM, self).__init__()
        
        self.lstm = LSTMCell(input_size, hidden_size)
        self.fc = Linear(hidden_size, output_size)
        self.train_alg = train_alg

    def forward(self, x, init_states=None):
        # x = x.squeeze(1)
        batch_size = x.shape[0]
        x = x.reshape(batch_size, x.shape[2], -1)
        seq_size = x.shape[1]
        hidden_size = self.lstm.hidden_size

        self.lstm.reset_pgrad()

        if init_states is None:
            h_t, c_t = (torch.zeros(batch_size, hidden_size, device=x.device), 
                        torch.zeros(batch_size, hidden_size, device=x.device))
        else:
            h_t, c_t = init_states
         
        for t in range(seq_size):
            x_t = x[:, t, :]
            h_t, c_t = self.lstm(x_t, h_t, c_t)

        logits = self.fc(h_t)

        return logits

    def pe_grad_norm(self, loss, batch_size, device):
        grad_norm = torch.zeros(batch_size, device=device)

        # Copy to avoid mutating the LSTM cell's pre_activation list.
        pre_acts = list(self.lstm.pre_activation)
        pre_acts.append(self.fc.pre_activation)
        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        grad_norm.add_(self.lstm.pe_grad_sqnorm(Z_grad[:-1]))
        grad_norm.add_(self.fc.pe_grad_sqnorm(Z_grad[-1]))
            
        grad_norm.sqrt_()

        return grad_norm
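
As with SimpleRNN, a hedged usage sketch (assuming the project-internal PeGradNet, LSTMCell, and Linear modules are available). The input layout mimics MNIST-style images that forward() reshapes into 28 time steps of 28 features; the clipping factor at the end is only an illustration of how the norms might be consumed.

import torch
import torch.nn.functional as F

batch_size = 16
model = SimpleLSTM(input_size=28, hidden_size=128, output_size=10)

# (batch, 1, 28, 28) is reshaped inside forward() to (batch, 28, 28).
x = torch.randn(batch_size, 1, 28, 28)
y = torch.randint(0, 10, (batch_size,))

logits = model(x)
loss = F.cross_entropy(logits, y, reduction='sum')

norms = model.pe_grad_norm(loss, batch_size, x.device)
# Illustrative DP-SGD style clipping factor with max_norm = 1.0.
clip_factor = torch.clamp(1.0 / (norms + 1e-6), max=1.0)
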
Example #3
class TransformerModel(nn.Module):
    def __init__(self, n_token, n_classes, d_model=512, n_layers=2,
                 n_head=8, n_hidden=2048, dropout=0.1, max_seq_len=512,
                 embeddings=None, train_alg='batch'):
        super(TransformerModel, self).__init__()

        self.train_alg = train_alg
        self.d_model = d_model
        self.n_head = n_head

        if embeddings is None:            
            self.token_embedding = nn.Embedding(n_token, d_model)
        else:
            self.token_embedding = nn.Embedding.from_pretrained(embeddings)
            self.token_embedding.weight.requires_grad = False

        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_len)        
        encoder_layers = TransformerEncoderLayer(d_model, n_head, n_hidden, dropout)
        # encoder_norm = nn.LayerNorm(d_model)
        encoder_norm = None
        self.encoder = TransformerEncoder(encoder_layers, n_layers, encoder_norm)
        self.fc = Linear(d_model, n_classes)

    def init_weights(self):
        initrange = 0.1
        self.token_embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        # positions = torch.arange(len(x), device=x.device).unsqueeze(-1)

        x = x.transpose(0, 1)
        # [sentence length, batch_size]
        x = self.token_embedding(x)
        # [sentence length, batch_size, embedding dim]
        x = self.pos_encoder(x)
        # x = x + self.pos_encoder(positions).expand_as(x)
        
        # [sentence length, batch_size, embedding dim]
        output = self.encoder(x)
        # [sentence length, batch_size, embedding dim]
        avg_out = output.transpose(0, 1).mean(dim=1)
        # [batch_size, embedding dim]
        preact = self.fc(avg_out)
        
        # [batch_size, num_classes]
        # return F.log_softmax(output, dim=-1)
        return preact

    def per_example_gradient(self, loss):
        grads = []
        pre_acts = []

        pre_acts.extend(self.encoder.collect_preactivations())
        pre_acts.append(self.fc.pre_activation)

        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        # TransformerEncoder
        grads.extend(self.encoder.per_example_gradient(Z_grad[:-1]))

        # fully connected layer
        grads.extend(self.fc.per_example_gradient(Z_grad[-1]))

        return grads

    def pe_grad_norm(self, loss, batch_size, device):
        grad_norm = torch.zeros(batch_size, device=device)

        pre_acts = []
        pre_acts.extend(self.encoder.collect_preactivations())
        pre_acts.append(self.fc.pre_activation)

        Z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)

        grad_norm.add_(self.encoder.pe_grad_sqnorm(Z_grad[:-1]))
        grad_norm.add_(self.fc.pe_grad_sqnorm(Z_grad[-1]))        
        grad_norm.sqrt_()

        return grad_norm
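
A usage sketch for the text-classification model above (assuming the project-internal PositionalEncoding, TransformerEncoder, TransformerEncoderLayer, and Linear modules are importable). The vocabulary size, sequence length, and class count are made up for illustration.

import torch
import torch.nn.functional as F

batch_size, seq_len, n_token, n_classes = 8, 128, 10000, 4
model = TransformerModel(n_token, n_classes, d_model=128,
                         n_layers=2, n_head=4, n_hidden=256)

# Token ids of shape (batch, seq_len); forward() transposes to (seq_len, batch).
tokens = torch.randint(0, n_token, (batch_size, seq_len))
labels = torch.randint(0, n_classes, (batch_size,))

logits = model(tokens)
loss = F.cross_entropy(logits, labels, reduction='sum')

norms = model.pe_grad_norm(loss, batch_size, tokens.device)  # shape: (8,)
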
Example #4
class TransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation="relu",
                 pe_grad=True):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        self._pe_modules = [
            self.self_attn, self.linear1, self.linear2, self.norm1, self.norm2
        ]

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Self-attention block with residual connection and layer norm.
        src2 = self.self_attn(src,
                              src,
                              src,
                              attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-wise feed-forward block with residual connection.
        if hasattr(self, "activation"):
            src2 = self.linear2(self.dropout(self.activation(
                self.linear1(src))))
        else:  # for backward compatibility
            src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        out = src + self.dropout2(src2)
        out = self.norm2(out)

        return out

    def per_example_gradient(self):
        grads = []

        for m in self._pe_modules:
            grads.extend(m.per_example_gradient())

        return grads

    def pe_grad_sqnorm(self, deriv_pre_activ):
        batch_size = deriv_pre_activ[0].size(1)
        device = deriv_pre_activ[0].device
        grad_sq_norm = torch.zeros(batch_size, device=device)

        grad_sq_norm.add_(self.self_attn.pe_grad_sqnorm(deriv_pre_activ[:2]))
        grad_sq_norm.add_(self.linear1.pe_grad_sqnorm(deriv_pre_activ[2]))
        grad_sq_norm.add_(self.linear2.pe_grad_sqnorm(deriv_pre_activ[3]))
        grad_sq_norm.add_(self.norm1.pe_grad_sqnorm(deriv_pre_activ[4]))
        grad_sq_norm.add_(self.norm2.pe_grad_sqnorm(deriv_pre_activ[5]))

        return grad_sq_norm

    def collect_preactivations(self):
        pre_acts = []

        pre_acts.extend(self.self_attn.collect_preactivations())
        pre_acts.append(self.linear1.pre_activation)
        pre_acts.append(self.linear2.pre_activation)
        pre_acts.append(self.norm1.pre_activation)
        pre_acts.append(self.norm2.pre_activation)

        return pre_acts
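
A standalone sketch of the layer in isolation (assuming the project-internal MultiheadAttention, Linear, and LayerNorm modules record their pre-activations during forward()). The sum-of-outputs loss is only a stand-in to show how collect_preactivations() and pe_grad_sqnorm() fit together.

import torch

seq_len, batch_size, d_model = 16, 4, 64
layer = TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=128)

# Input follows the (seq_len, batch, d_model) layout used by forward().
src = torch.randn(seq_len, batch_size, d_model)
out = layer(src)

# Differentiate a scalar loss w.r.t. the recorded pre-activations, then
# reduce to per-example squared gradient norms.
loss = out.sum()
pre_acts = layer.collect_preactivations()
z_grad = torch.autograd.grad(loss, pre_acts, retain_graph=True)
sq_norms = layer.pe_grad_sqnorm(z_grad)  # shape: (batch_size,)
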