class fofe_filter(nn.Module):
    """Depthwise, causal convolution with a frozen exponentially-decaying kernel.

    Every channel is convolved with the same non-trainable kernel
    ``[alpha**(length-1), ..., alpha, 1]`` (or ``[1, alpha, ..., alpha**(length-1)]``
    when ``inverse=True``). The input is left-padded with ``length - 1`` zeros,
    so with the default kernel ``out[..., t] = sum_k alpha**k * x[..., t - k]``
    for ``k`` in ``[0, length)``, and the output length equals the input length.

    Args:
        inplanes: number of channels; the input passed to ``forward`` must have
            this many channels in the dimension ``F.conv1d`` treats as channels
            (``(batch, inplanes, seq_len)`` or ``(inplanes, seq_len)``).
        alpha: decay factor used as the base of the kernel powers.
        length: kernel (window) length.
        inverse: if True, reverse the kernel so the largest weight (1) is applied
            to the oldest element of the window.
    """

    def __init__(self, inplanes, alpha=0.8, length=3, inverse=False):
        super(fofe_filter, self).__init__()
        self.length = length
        self.channels = inplanes
        self.alpha = alpha
        # Kept as a frozen Parameter (not a buffer) so the state_dict layout is unchanged.
        self.fofe_filter = Parameter(torch.Tensor(inplanes, 1, length))
        self.fofe_filter.requires_grad_(False)
        self._init_filter(alpha, length, inverse)
        # Not used by this module's forward pass; kept as a public attribute.
        self.padding = (length - 1) // 2

    def _init_filter(self, alpha, length, inverse):
        """Fill every channel's kernel with powers of ``alpha``."""
        # Exponents 0, 1, ..., length-1 in the filter's dtype. torch.range was used
        # here before; it is deprecated (inclusive end) and emits a UserWarning.
        exponents = torch.arange(length, dtype=self.fofe_filter.dtype)
        if not inverse:
            # Default kernel: [alpha**(length-1), ..., alpha, 1]
            exponents = exponents.flip(0)
        with torch.no_grad():
            # The (length,) vector broadcasts over the (inplanes, 1, length) kernel.
            self.fofe_filter.copy_(torch.pow(alpha, exponents))

    def fofe_encode(self, x):
        """Left-pad by ``length - 1`` zeros and apply the depthwise kernel."""
        out = F.pad(x, (self.length - 1, 0), mode='constant', value=0)
        out = F.conv1d(out, self.fofe_filter, bias=None, stride=1, padding=0,
                       groups=self.channels)
        return out

    def forward(self, x):
        # alpha == 0 makes the default kernel [0, ..., 0, 1], i.e. the identity.
        # NOTE(review): alpha == 1 also returns x unchanged instead of the windowed
        # sum the kernel would compute; presumably a deliberate "disable" switch.
        if self.alpha == 1 or self.alpha == 0:
            return x
        x = self.fofe_encode(x)
        return x
class fofe_conv1d(nn.Module):
    """Frozen power-of-``alpha`` depthwise convolution followed by a dilated conv.

    ``forward`` expects the embedding dimension last, e.g.
    ``(batch, seq_len, emb_dims)``, and returns a channels-first tensor
    ``(batch, emb_dims, seq_len_out)`` (it is not transposed back).

    Steps in ``forward``:
      1. transpose the last two dims to channels-first;
      2. for an even ``length``, append one zero on the right so that the
         symmetric padding of step 3 keeps the sequence length unchanged;
      3. depthwise convolution (padding ``(length - 1) // 2``) with the frozen
         kernel ``[alpha**(length-1), ..., alpha, 1]``, reversed when
         ``inverse=True``;
      4. ``Conv1d(emb_dims, emb_dims, 3, padding=length, dilation=dilation)``
         followed by ``LeakyReLU(0.1)``.

    Args:
        emb_dims: number of channels (embedding size).
        alpha: decay factor used as the base of the kernel powers.
        length: FOFE kernel length; also used as the padding of the dilated conv.
        dilation: dilation of the trailing 3-tap convolution.
        inverse: if True, reverse the FOFE kernel.
    """

    def __init__(self, emb_dims, alpha=0.9, length=1, dilation=1, inverse=False):
        super(fofe_conv1d, self).__init__()
        self.alpha = alpha
        self.length = length
        self.channels = emb_dims
        # Kept as a frozen Parameter (not a buffer) so the state_dict layout is unchanged.
        self.fofe_filter = Parameter(torch.Tensor(emb_dims, 1, length))
        self.fofe_filter.requires_grad_(False)
        self._init_filter(emb_dims, alpha, length, inverse)
        self.padding = (length - 1) // 2
        # NOTE(review): output length of this conv is L + 2*length - 2*dilation,
        # i.e. it only preserves the sequence length when length == dilation.
        # Left as-is because callers may depend on it; confirm intent.
        self.dilated_conv = nn.Sequential(
            nn.Conv1d(self.channels, self.channels, 3, 1, padding=length,
                      dilation=dilation, groups=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True))

    def _init_filter(self, channels, alpha, length, inverse):
        """Fill every channel's kernel with powers of ``alpha``.

        ``channels`` is accepted for interface compatibility; the kernel shape
        comes from ``self.fofe_filter``.
        """
        # torch.range was used here before; it is deprecated (inclusive end)
        # and emits a UserWarning. arange gives the same exponents 0..length-1.
        exponents = torch.arange(length, dtype=self.fofe_filter.dtype)
        if not inverse:
            # Default kernel: [alpha**(length-1), ..., alpha, 1]
            exponents = exponents.flip(0)
        with torch.no_grad():
            self.fofe_filter.copy_(torch.pow(alpha, exponents))

    def forward(self, x):
        x = torch.transpose(x, -2, -1)
        if self.length % 2 == 0:
            # Even kernels: one extra zero on the right keeps the output length
            # equal to the input length with padding (length - 1) // 2.
            x = F.pad(x, (0, 1), mode='constant', value=0)
        x = F.conv1d(x, self.fofe_filter, bias=None, stride=1,
                     padding=self.padding, groups=self.channels)
        x = self.dilated_conv(x)
        return x
def __init__(
    self,
    modules: Sequence[Module],
    original: Union[Tensor, Parameter],
    unsafe: bool = False,
) -> None:
    """Build the parametrization chain and register the original tensor(s).

    ``right_inverse`` of each module (last module first) is applied to
    ``original``. The result is registered either as a single tensor named
    ``original`` (written into ``original`` in place via ``set_``) or as
    ``original0``, ``original1``, ... when ``right_inverse`` returns a sequence.
    Unless ``unsafe`` is set, the chain is then evaluated once to check that it
    reproduces a tensor with the original dtype and shape.

    Args:
        modules: the parametrization modules; must be non-empty. They are passed
            to the base-class constructor (presumably ``nn.ModuleList`` — the
            class header is not visible here).
        original: the tensor (or Parameter) being parametrized.
        unsafe: if True, skip the final dtype/shape consistency checks.

    Raises:
        ValueError: if ``modules`` is empty; if ``right_inverse`` returns
            something other than a Tensor or a Sequence of Tensors; if a
            single-tensor ``right_inverse`` changes the dtype; or (when not
            ``unsafe``) if the parametrized value is not a Tensor or differs
            from ``original`` in dtype or shape.
    """
    # We require this because we need to treat differently the first parametrization
    # This should never throw, unless this class is used from the outside
    if len(modules) == 0:
        raise ValueError("ParametrizationList requires one or more modules.")

    super().__init__(modules)
    self.unsafe = unsafe

    # In plain words:
    # module.weight must keep its dtype and shape.
    # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
    # this should be of the same dtype as the original tensor
    #
    # We check that the following invariants hold:
    #    X = module.weight
    #    Y = param.right_inverse(X)
    #    assert isinstance(Y, Tensor) or
    #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
    #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
    #    # Consistency checks
    #    assert X.dtype == Z.dtype and X.shape == Z.shape
    #    # If it has one input, this allows to be able to use set_ to be able to
    #    # move data to/from the original tensor without changing its id (which is what the
    #    # optimiser uses to track parameters)
    #    if isinstance(Y, Tensor)
    #      assert X.dtype == Y.dtype
    # Below we use original = X, new = Y

    # Captured before any set_ below, which may replace original's storage.
    original_shape = original.shape
    original_dtype = original.dtype

    # Compute new: apply right_inverse from the last parametrization to the first
    with torch.no_grad():
        new = original
        for module in reversed(self):  # type: ignore[call-overload]
            if hasattr(module, "right_inverse"):
                try:
                    new = module.right_inverse(new)
                except NotImplementedError:
                    pass
            # else, or if it throws, we assume that right_inverse is the identity

    if not isinstance(new, Tensor) and not isinstance(
        new, collections.abc.Sequence
    ):
        raise ValueError(
            "'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
            f"Got {type(new).__name__}"
        )

    # Set the number of original tensors
    self.is_tensor = isinstance(new, Tensor)
    self.ntensors = 1 if self.is_tensor else len(new)

    # Register the tensor(s)
    if self.is_tensor:
        if original.dtype != new.dtype:
            raise ValueError(
                "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
                f"original.dtype: {original.dtype}\n"
                f"right_inverse(original).dtype: {new.dtype}"
            )
        # Write new into original in place so the user does not need to re-register
        # the parameter manually in the optimiser (the object's identity is kept)
        with torch.no_grad():
            original.set_(new)  # type: ignore[call-overload]
        _register_parameter_or_buffer(self, "original", original)
    else:
        for i, originali in enumerate(new):
            if not isinstance(originali, Tensor):
                raise ValueError(
                    "'right_inverse' must return a Tensor or a Sequence of tensors "
                    "(list, tuple...). "
                    f"Got element {i} of the sequence with type {type(originali).__name__}."
                )

            # If the original tensor was a Parameter that required grad, we expect the user to
            # add the new parameters to the optimizer after registering the parametrization
            # (this is documented)
            if isinstance(original, Parameter):
                originali = Parameter(originali)
            # Each new tensor inherits the requires_grad flag of the original
            originali.requires_grad_(original.requires_grad)
            _register_parameter_or_buffer(self, f"original{i}", originali)

    if not self.unsafe:
        # Consistency checks:
        # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
        # Z = forward(right_inverse(original))
        Z = self()
        if not isinstance(Z, Tensor):
            raise ValueError(
                f"A parametrization must return a tensor. Got {type(Z).__name__}."
            )
        if Z.dtype != original_dtype:
            raise ValueError(
                "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
                f"unparametrized dtype: {original_dtype}\n"
                f"parametrized dtype: {Z.dtype}"
            )
        if Z.shape != original_shape:
            raise ValueError(
                "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
                f"unparametrized shape: {original_shape}\n"
                f"parametrized shape: {Z.shape}"
            )
class Convolution(nn.Module):
    r"""2D convolution applied to an input spike-wave made of several input planes.

    This version only supports a stride of 1 and no padding. The input is a 4D
    tensor of size :math:`(T, C_{in}, H_{in}, W_{in})` and the corresponding
    output has size :math:`(T, C_{out}, H_{out}, W_{out})`, where :math:`T` is
    the number of time steps, :math:`C` the number of feature maps (channels),
    and :math:`H` and :math:`W` the height and width of the input/output planes.

    * :attr:`in_channels` controls the number of input planes (channels/feature maps).
    * :attr:`out_channels` controls the number of feature maps in the current layer.
    * :attr:`kernel_size` controls the size of the convolution kernel. It can be a
      single integer or a tuple of two integers.
    * :attr:`weight_mean` controls the mean of the normal distribution used for
      the initial random weights.
    * :attr:`weight_std` controls the standard deviation of the normal distribution
      used for the initial random weights.

    .. note::
        Padding is not supported, so it is the caller's responsibility to pad the
        input before applying the convolution.

    The weight is a ``Parameter`` with ``requires_grad=False``: it is changed
    through :meth:`reset_weight` / :meth:`load_weight`, not by autograd.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        weight_mean (float, optional): Mean of the initial random weights. Default: 0.8
        weight_std (float, optional): Standard deviation of the initial random weights. Default: 0.02
    """

    def __init__(self, in_channels, out_channels, kernel_size, weight_mean=0.8, weight_std=0.02):
        super(Convolution, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = to_pair(kernel_size)

        # Fixed convolution hyper-parameters, kept as attributes for future use.
        self.stride, self.padding, self.dilation, self.groups = 1, 0, 1, 1
        self.bias = None

        # Weight layout expected by conv2d: (out_channels, in_channels, kH, kW).
        weight_shape = (self.out_channels, self.in_channels, *self.kernel_size)
        self.weight = Parameter(torch.empty(weight_shape), requires_grad=False)
        self.reset_weight(weight_mean, weight_std)

    def reset_weight(self, weight_mean=0.8, weight_std=0.02):
        """Redraw the weights in place from a normal distribution.

        Args:
            weight_mean (float, optional): Mean of the random weights. Default: 0.8
            weight_std (float, optional): Standard deviation of the random weights. Default: 0.02
        """
        self.weight.normal_(mean=weight_mean, std=weight_std)

    def load_weight(self, target):
        """Copy ``target`` into the weights in place.

        Args:
            target (Tensor): Tensor whose values are copied into the weights.
        """
        self.weight.copy_(target)

    def forward(self, input):
        return fn.conv2d(
            input,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class ReLuTransformer(torch.nn.Module):
    """Element-wise ReLU transformer on a ``(a0, A)`` abstract element.

    ``A_in`` is a pair ``(a0, A)``: ``a0`` has shape ``x_shape`` and ``A`` has
    shape ``(n_eps,) + x_shape``. Per element, the bounds are
    ``l_x = a0 - sum|A|`` and ``u_x = a0 + sum|A|``. The output is
    ``(a0_n, A_n)``, where ``A_n`` has one extra row for every element whose
    bounds cross zero:

    * ``l_x >= 0``: ReLU is the identity, so the element passes through unchanged.
    * ``l_x < 0 < u_x``: the value is relaxed with a learnable per-element slope
      ``s`` and offset ``max(u_x*(1-s), -l_x*s) / 2``. The centre becomes
      ``s*a0 + offset``, the coefficients are scaled by ``s``, and the offset
      becomes the coefficient of a fresh error term.
    * otherwise (``u_x <= 0``): the output is zero.

    The slopes are initialised from the first input seen (see
    ``reset_parameters``).

    NOTE(review): ``initialized`` is not part of the state_dict, so slopes loaded
    from a checkpoint are overwritten on the first forward call; confirm intent.
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, D_features, x_shape):
        super(ReLuTransformer, self).__init__()
        self.in_features = D_features
        self.out_features = D_features + x_shape
        # torch.empty treats x_shape as a shape; torch.Tensor(<tuple>) would build
        # a tensor from the tuple's values instead. The values are overwritten by
        # reset_parameters on the first forward call.
        self.slopes = Parameter(torch.empty(x_shape))
        self.slopes.requires_grad_(True)
        self.x_shape = x_shape
        self.initialized = False

    def reset_parameters(self, A_in):
        """Initialise the slopes as ``clamp(0.5 * (1 + a0 / sum|A|), 0, 1)``."""
        A = A_in[1].detach()
        a0 = A_in[0].detach()
        temp_slope = 0.5 * (1.0 + torch.div(a0, torch.abs(A).sum(axis=0)))
        # a0 == 0 with an all-zero coefficient column gives 0/0 = NaN. clamp keeps
        # NaN, and it would leak into the output via 0 * NaN in the masked sums.
        # 0.5 is what the formula yields for a0 == 0 with a non-zero column.
        temp_slope = torch.where(torch.isnan(temp_slope),
                                 torch.full_like(temp_slope, 0.5), temp_slope)
        temp_slope = torch.clamp(temp_slope, 0, 1)
        self.slopes.data = temp_slope

    def forward(self, A_in):
        if not self.initialized:
            self.reset_parameters(A_in)
            self.initialized = True
        A = A_in[1]  # eps x x_shape
        a0 = A_in[0]  # x_shape
        radius = torch.abs(A).sum(axis=0)
        l_x = a0 - radius
        u_x = a0 + radius
        # 1.0 where the input interval strictly crosses zero, else 0.0
        crossover_idx = torch.mul((l_x < 0).float(), (u_x > 0).float())
        # ">=" (not ">"): on [0, u] ReLU is the identity, so the error terms must be
        # kept; with ">" an element with l_x == 0 exactly would lose them (unsound).
        positive_idx = (l_x >= 0).float()
        offset_factor = torch.max(u_x * (1 - self.slopes), -l_x * self.slopes) / 2.
        a0_n = torch.mul((1 - crossover_idx), torch.nn.functional.relu(a0)) \
            + torch.mul(crossover_idx, torch.mul(self.slopes, a0) + offset_factor)
        A_new_e, A_0 = self.assemble_hyper_diagonal(self.x_shape, offset_factor, crossover_idx)
        A_n = torch.mul(positive_idx.unsqueeze(0), torch.cat((A, A_0), dim=0)) \
            + torch.mul(crossover_idx.unsqueeze(0),
                        torch.cat((torch.mul(self.slopes.unsqueeze(0), A), A_new_e), dim=0))
        # DEBUG is a module-level flag defined outside this class.
        if DEBUG:
            l_x_n = -torch.abs(A_n).sum(axis=0) + a0_n
            u_x_n = torch.abs(A_n).sum(axis=0) + a0_n
            if not (((u_x_n-u_x)/abs((u_x_n-u_x)).median())>-1e-4).all() and (u_x_n-u_x).min()<-1e-5:
                print("ReLU upper bound soundness check failed")
            if not (((l_x_n-l_x)/abs((l_x_n-l_x)).median())>-1e-4).all() and (l_x_n-l_x).min()<-1e-5:
                print("ReLU lower bound soundness check failed")
        return (a0_n, A_n)

    def assemble_hyper_diagonal(self, x_shape, offset_factor, crossover_idx):
        """Build the coefficient rows for the new error terms.

        Returns ``(A_new, A_0)``, both of shape ``(k,) + x_shape`` where ``k`` is
        the number of entries with ``crossover_idx == 1``. Row ``i`` of ``A_new``
        is zero except at the i-th such entry (row-major order), which holds
        ``offset_factor`` there. ``A_0`` is all zeros. Both tensors use the
        device and dtype of ``offset_factor``. The previous CPU-only
        ``torch.sparse.FloatTensor`` / float32 construction broke for CUDA or
        non-float32 inputs.
        """
        mask = crossover_idx == 1
        k = int(mask.sum())
        shape = (k,) + tuple(x_shape)
        A_new = torch.zeros(shape, dtype=offset_factor.dtype, device=offset_factor.device)
        if k > 0:
            # nonzero() and boolean-mask indexing both enumerate in row-major order,
            # so row i receives the i-th crossing entry's offset.
            positions = mask.nonzero(as_tuple=True)
            rows = torch.arange(k, device=offset_factor.device)
            A_new[(rows,) + positions] = offset_factor[mask]
        A_0 = torch.zeros(shape, dtype=offset_factor.dtype, device=offset_factor.device)
        return A_new, A_0