class VDropCentralData(nn.Module):
    """
    Stores data for a set of variational dropout (VDrop) modules in large
    central tensors. The VDrop modules access the data using views.

    This makes it possible to operate on all of the data at once (rather than
    e.g. 53 times with resnet50).

    Usage:
    1. Instantiate
    2. Pass into multiple constructed VDropLinear and VDropConv2d modules
    3. Call finalize()

    Before calling forward on the model, call "compute_forward_data". After
    calling forward on the model, call "clear_forward_data".

    The parameters are stored in terms of z_mu and z_var rather than w_mu and
    w_var to support group variational dropout (e.g. to allow for pruning
    entire channels).
    """
    def __init__(self, z_logvar_init=-10):
        super().__init__()
        self.z_chunk_sizes = []
        self.z_logvar_init = z_logvar_init
        self.z_logvar_min = min(z_logvar_init, -10)
        self.z_logvar_max = 10.
        self.epsilon = 1e-8
        self.data_views = {}
        self.modules = []

        # Populated during register(), deleted during finalize()
        self.all_z_mu = []
        self.all_z_logvar = []
        self.all_num_weights = []

        # Populated during finalize()
        self.z_mu = None
        self.z_logvar = None
        self.z_num_weights = None

        self.threshold = 3

    def extra_repr(self):
        s = f"z_logvar_init={self.z_logvar_init}"
        return s

    def __getitem__(self, key):
        return self.data_views[key]

    def register(self, module, z_mu, z_logvar, num_weights_per_z=1):
        self.all_z_mu.append(z_mu.flatten())
        self.all_z_logvar.append(z_logvar.flatten())
        self.all_num_weights.append(num_weights_per_z)
        self.modules.append(module)
        data_index = len(self.z_chunk_sizes)
        self.z_chunk_sizes.append(z_mu.numel())
        return data_index

    def finalize(self):
        self.z_mu = Parameter(torch.cat(self.all_z_mu))
        self.z_logvar = Parameter(torch.cat(self.all_z_logvar))
        self.z_num_weights = torch.tensor(
            self.all_num_weights,
            dtype=torch.float).repeat_interleave(torch.tensor(self.z_chunk_sizes))
        del self.all_z_mu
        del self.all_z_logvar
        del self.all_num_weights

    def to(self, *args, **kwargs):
        ret = super().to(*args, **kwargs)
        self.z_num_weights = self.z_num_weights.to(*args, **kwargs)
        return ret

    def compute_forward_data(self):
        if self.training:
            self.data_views["z_mu"] = self.z_mu.split(self.z_chunk_sizes)
            self.data_views["z_var"] = self.z_logvar.exp().split(
                self.z_chunk_sizes)
        else:
            self.data_views["z_mu"] = (
                self.z_mu
                * (self.compute_z_logalpha() < self.threshold).float()
            ).split(self.z_chunk_sizes)

    def clear_forward_data(self):
        self.data_views.clear()

    def compute_z_logalpha(self):
        return self.z_logvar - (self.z_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_z_logalpha())
                * self.z_num_weights).sum()

    def constrain_parameters(self):
        self.z_logvar.data.clamp_(min=self.z_logvar_min, max=self.z_logvar_max)
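
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of driving VDropCentralData, assuming the companion
# VDropLinear module accepts the central data object as a constructor argument
# (as the docstring above describes); the exact signature is an assumption.
def _example_vdrop_central_data():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    vdrop_data = VDropCentralData(z_logvar_init=-10)
    model = nn.Sequential(
        VDropLinear(784, 256, vdrop_data),  # hypothetical constructor signature
        nn.ReLU(),
        VDropLinear(256, 10, vdrop_data),
    )
    vdrop_data.finalize()  # concatenate all registered z_mu / z_logvar chunks

    x = torch.randn(32, 784)
    targets = torch.randint(0, 10, (32,))

    vdrop_data.compute_forward_data()   # build per-module views before forward
    loss = F.cross_entropy(model(x), targets)
    loss = loss + 1e-4 * vdrop_data.regularization()
    loss.backward()
    vdrop_data.clear_forward_data()     # drop the views after forward
    vdrop_data.constrain_parameters()   # clamp z_logvar into [min, max]
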
class VDropLinear2(nn.Module):
    """
    A self-contained VDropLinear (doesn't use the VDropCentralData).
    """
    def __init__(self, in_features, out_features, bias=True, w_logvar_init=-10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(torch.Tensor(self.out_features, self.in_features))
        self.w_logvar = Parameter(
            torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)

        # Standard nn.Linear initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available()
                                   else torch.cuda.FloatTensor)

    def extra_repr(self):
        s = f"{self.in_features}, {self.out_features}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu

    def get_w_var(self):
        return self.w_logvar.exp()

    def forward(self, x):
        if self.training:
            return vdrop_linear_forward(x, self.get_w_mu, self.get_w_var,
                                        self.bias, self.tensor_constructor)
        else:
            return F.linear(x, self.get_w_mu(), self.bias)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return vdrop_regularization(self.compute_w_logalpha()).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
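
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of the self-contained layer above. In training mode the
# forward pass goes through the vdrop_linear_forward helper (stochastic); in
# eval mode it falls back to a plain F.linear with w_mu.
def _example_vdrop_linear2():
    import torch

    layer = VDropLinear2(in_features=128, out_features=64)
    x = torch.randn(16, 128)

    layer.train()
    y = layer(x)                      # stochastic forward
    reg = layer.regularization()      # add this term to the training loss

    layer.eval()
    y_det = layer(x)                  # deterministic forward using w_mu only
    layer.constrain_parameters()      # keep w_logvar inside its allowed range
    return y, reg, y_det
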
class MaskedVDropConv2d(nn.Module):
    """
    A self-contained masked Conv2d (doesn't use the VDropCentralData).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, mask=None,
                 w_logvar_init=-10):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.groups = groups
        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.w_logvar = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)

        self.register_buffer(
            "w_mask",
            torch.HalfTensor(out_channels, in_channels // groups,
                             *self.kernel_size))

        # Standard nn.Conv2d initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))

        if mask is not None:
            self.w_mask[:] = mask
            self.w_mu.data *= self.w_mask
            self.w_logvar.data[self.w_mask == 0.0] = self.pruned_logvar_sentinel
        else:
            self.w_mask.fill_(1.0)

        # Standard nn.Conv2d initialization.
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available()
                                   else torch.cuda.FloatTensor)

    def extra_repr(self):
        s = (f"{self.in_channels}, {self.out_channels}, "
             f"kernel_size={self.kernel_size}, stride={self.stride}")
        if self.padding != (0, ) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1, ) * len(self.dilation):
            s += f", dilation={self.dilation}"
        if self.groups != 1:
            s += f", groups={self.groups}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu * self.w_mask

    def get_w_var(self):
        return self.w_logvar.exp() * self.w_mask

    def forward(self, x):
        if self.training:
            return vdrop_conv_forward(x, self.get_w_mu, self.get_w_var,
                                      self.bias, self.stride, self.padding,
                                      self.dilation, self.groups,
                                      self.tensor_constructor)
        else:
            return F.conv2d(x, self.get_w_mu(), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_w_logalpha())
                * self.w_mask).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
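
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of pruning with a fixed binary mask: weights where the mask
# is zero are excluded from the forward pass and from the regularization term,
# and their w_logvar is set to the pruned sentinel value at construction.
def _example_masked_vdrop_conv2d():
    import torch

    mask = torch.ones(8, 4, 3, 3)
    mask[:, 0] = 0.0                  # prune every weight reading input channel 0

    conv = MaskedVDropConv2d(4, 8, kernel_size=3, padding=1, mask=mask)
    x = torch.randn(2, 4, 32, 32)

    conv.train()
    y = conv(x)                       # stochastic forward via vdrop_conv_forward
    reg = conv.regularization()       # masked weights contribute nothing here

    conv.eval()
    y_det = conv(x)                   # deterministic conv with w_mu * w_mask
    return y, reg, y_det
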
class Linear(nn.Module):
    """
    A linear layer with an optional output activation. When hidden_dim is
    given, it stacks several linear layers (a small MLP) with
    hidden_activation between them and activation on the output.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 activation="ReLU", hidden_dim=None,
                 hidden_activation="ReLU") -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        self.hidden_activation = hidden_activation
        if hidden_dim is None:
            self.dims = vector(in_features, out_features)
            self.weight = Parameter(torch.zeros(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.zeros(out_features))
            else:
                self.register_parameter('bias', None)
            self.activation = get_activation_layer(activation)
        else:
            self.dims = vector(in_features, *vector(hidden_dim), out_features)
            self.weight = nn.ParameterList(self.dims.map_k(
                lambda in_dim, out_dim: Parameter(torch.zeros(out_dim, in_dim)), 2))
            if bias:
                self.bias = nn.ParameterList(self.dims.map_k(
                    lambda in_dim, out_dim: Parameter(torch.zeros(out_dim)), 2))
            else:
                self.register_parameter('bias', None)
            self.activation = vector(get_activation_layer(hidden_activation)
                                     for _ in range(len(hidden_dim)))
            self.activation.append(get_activation_layer(activation))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.hidden_dim is None:
            if isinstance(self.activation, torch.nn.ReLU) or self.activation == torch.relu:
                init.kaiming_normal_(self.weight, a=0, mode='fan_in',
                                     nonlinearity='relu')
            else:
                init.xavier_normal_(self.weight)
        else:
            for a, w in zip(self.activation, self.weight):
                if isinstance(a, torch.nn.ReLU) or a == torch.relu:
                    init.kaiming_normal_(w, a=0, mode='fan_in',
                                         nonlinearity='relu')
                else:
                    init.xavier_normal_(w)

    def forward(self, input: Tensor) -> Tensor:
        if self.hidden_dim is None:
            if self.activation is None:
                return F.linear(input, self.weight, self.bias)
            else:
                return self.activation(F.linear(input, self.weight, self.bias))
        else:
            h = input
            if self.bias is None:
                for w, a in zip(self.weight, self.activation):
                    h = a(F.linear(h, w, None))
            else:
                for w, b, a in zip(self.weight, self.bias, self.activation):
                    h = a(F.linear(h, w, b))
            return h

    def extra_repr(self) -> str:
        if self.activation is None:
            return 'in_features={}, out_features={}, bias={}'.format(
                self.in_features, self.out_features, self.bias is not None)
        elif isinstance(self.activation, vector):
            ret = 'in_features={}, out_features={}, bias={}, activation={}\n'.format(
                self.in_features, self.out_features, self.bias is not None,
                self.activation.map(lambda x: touch(lambda: x.__name__, str(x))))
            ret += "{}".format(self.in_features)
            for d, a in zip(self.dims[1:], self.activation):
                ret += '->{}->{}'.format(d, touch(lambda: a.__name__, str(a)))
            return ret
        else:
            ret = 'in_features={}, out_features={}, bias={}, activation={}'.format(
                self.in_features, self.out_features, self.bias is not None,
                touch(lambda: self.activation.__name__, str(self.activation)))
            return ret

    def regulization_loss(self, p=2):
        if self.hidden_dim is None:
            if p == 2:
                return self.weight.square().sum()
            if p == 1:
                return self.weight.abs().sum()
            return (self.weight.abs() ** p).sum()
        else:
            reg = []
            for w in self.weight:
                reg.append((w.abs() ** p).sum())
            return sum(reg)
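
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of the Linear layer above, assuming the vector and
# get_activation_layer helpers it relies on are importable from this
# repository. With hidden_dim set, the layer behaves like a small MLP.
def _example_linear():
    import torch

    fc = Linear(64, 10, hidden_dim=[128, 128])   # 64 -> 128 -> 128 -> 10, ReLU throughout
    x = torch.randn(5, 64)
    out = fc(x)
    reg = fc.regulization_loss(p=2)              # sum of squared weights over all layers
    return out, reg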