import torch
import torch.nn as nn
from torch.nn import Parameter, init


class LassoFeatureSelection(nn.Module):
    """
    1-to-1 layer meant to be used as the first layer in the network. Strong
    L1 regularization on its per-feature weights should enforce feature
    selection behavior. Does nothing if initialized with lasso_value=0.0.
    """

    def __init__(self, input_size, lasso_value=0.0):
        super().__init__()
        self.lasso_value = lasso_value
        if self.lasso_value != 0:
            # One learnable multiplier per input feature, initialized to 1.
            self.mul = Parameter(torch.ones(input_size))

    def forward(self, x):
        if self.lasso_value != 0:
            return x * self.mul
        return x

    def loss(self):
        # L1 penalty on the feature multipliers; add this to the task loss.
        if self.lasso_value != 0:
            return self.mul.norm(1) * self.lasso_value
        return 0

    def get_values(self):
        # Current per-feature weights as a NumPy array, for inspection.
        if self.lasso_value != 0:
            return self.mul.detach().cpu().numpy()
        return 0
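# A minimal training sketch (the linear head, shapes, and data below are
# illustrative assumptions, not part of this module): the layer's L1 penalty
# must be added to the task loss by hand, since PyTorch does not collect
# custom regularizers automatically.
def _lasso_usage_sketch():
    import torch.optim as optim

    selector = LassoFeatureSelection(input_size=8, lasso_value=1e-3)
    head = nn.Linear(8, 1)  # stand-in for the rest of the network
    opt = optim.SGD(list(selector.parameters()) + list(head.parameters()),
                    lr=0.1)

    x = torch.randn(32, 8)  # dummy batch
    y = torch.randn(32, 1)  # dummy targets
    pred = head(selector(x))
    # Task loss plus the L1 feature-selection penalty.
    loss = nn.functional.mse_loss(pred, y) + selector.loss()
    opt.zero_grad()
    loss.backward()
    opt.step()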
class GRUCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_norm=True):
        """
        GRU cell with an optional layer-normalized variant.

        Parameters
        ----------
        input_dim : int
            Dimensionality of GRU cell input.
        hidden_dim : int
            Dimensionality of GRU cell hidden state.
        layer_norm : bool
            Whether to use the layer normalized version of the GRU cell.
        """
        super(GRUCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.layer_norm = layer_norm

        # Input-to-hidden weights. The raw tensors are initialized first and
        # only then wrapped in Parameter: torch.cat on Parameters would
        # return a plain tensor, so the pieces would not be registered as
        # learnable parameters of the module.
        W_iz = torch.Tensor(input_dim, hidden_dim)
        W_ir = torch.Tensor(input_dim, hidden_dim)
        self.W = torch.Tensor(input_dim, hidden_dim)
        # Xavier init input-to-hidden
        for w in [W_iz, W_ir, self.W]:
            init.xavier_uniform_(w.data)
        self.W = Parameter(self.W)
        self.W_i = Parameter(torch.cat([W_iz, W_ir], dim=1))

        # Hidden-to-hidden weights, same Parameter-after-init pattern.
        W_hz = torch.Tensor(hidden_dim, hidden_dim)
        W_hr = torch.Tensor(hidden_dim, hidden_dim)
        self.U = torch.Tensor(hidden_dim, hidden_dim)
        # Orthogonal init hidden-to-hidden
        for w in [W_hz, W_hr, self.U]:
            init.orthogonal_(w.data)
        self.U = Parameter(self.U)
        self.W_h = Parameter(torch.cat([W_hz, W_hr], dim=1))

        if self.layer_norm:
            self.ln1 = nn.LayerNorm(2 * hidden_dim)
            self.ln2 = nn.LayerNorm(2 * hidden_dim)
            self.ln3 = nn.LayerNorm(hidden_dim)
            self.ln4 = nn.LayerNorm(hidden_dim)

    def forward(self, x, h):
        # Update (z) and reset (r) gate pre-activations, one matmul each.
        if self.layer_norm:
            gates = self.ln1(torch.mm(h, self.W_h)) + \
                self.ln2(torch.mm(x, self.W_i))
        else:
            gates = torch.mm(h, self.W_h) + torch.mm(x, self.W_i)
        z, r = gates.chunk(2, dim=1)

        # Debug guard: abort if the hidden-to-hidden weights have diverged.
        if torch.isnan(self.W_h).any():
            print(self.W_h)
            exit()

        # Candidate hidden state, gated by the reset gate.
        if self.layer_norm:
            hh = torch.tanh(
                self.ln3(torch.mm(x, self.W)) +
                r.sigmoid() * self.ln4(torch.mm(h, self.U)))
        else:
            hh = torch.tanh(
                torch.mm(x, self.W) + r.sigmoid() * torch.mm(h, self.U))

        # Interpolate between the old state and the candidate.
        z = z.sigmoid()
        h = (1 - z) * h + z * hh
        return h
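# A quick shape check (the dimensions below are illustrative assumptions):
# run one step of the cell on a dummy batch and confirm the hidden state
# keeps its (batch, hidden_dim) shape.
if __name__ == "__main__":
    batch, input_dim, hidden_dim = 4, 10, 16
    cell = GRUCell(input_dim, hidden_dim, layer_norm=True)
    x = torch.randn(batch, input_dim)
    h = torch.zeros(batch, hidden_dim)
    h = cell(x, h)
    print(h.shape)  # expected: torch.Size([4, 16])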