class RandomProjection(nn.Module):
    """Bias-free linear map ``y = input @ weight.T`` with a frozen random matrix.

    ``weight`` has shape ``(out_features, in_features)``, is registered with
    ``requires_grad=False`` (it is never trained) and is drawn from a normal
    distribution with mean 0 and standard deviation 0.1 by
    :meth:`reset_parameters`.
    """

    def __init__(self, in_features, out_features):
        super(RandomProjection, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        # Still a Parameter (so it is part of the module's parameters/state),
        # but frozen: requires_grad=False keeps the projection fixed.
        frozen = torch.Tensor(out_features, in_features)
        self.weight = Parameter(frozen, requires_grad=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Redraw ``weight`` in place from N(0, 0.1^2)."""
        # Experimentally, std=1 distorted the output scale too much, hence 0.1.
        # Another option that was considered: random +/-1 entries, e.g.
        #   self.weight.bernoulli_(0.5).mul_(2).sub_(1)
        self.weight.normal_(mean=0.0, std=0.1)

    def forward(self, input):
        """Project ``input`` along its last dimension (no bias term)."""
        return F.linear(input, self.weight)
class StochasticConv2D(torch.nn.modules.Conv2d):
    """2D convolution whose weights (and bias) are perturbed with Gaussian noise.

    Constructor arguments are the same as :class:`torch.nn.Conv2d`. The inherited
    ``weight``/``bias`` are the means of the perturbation. The layer additionally
    registers tensors of the same shapes:

    * ``weight_log_std`` / ``bias_log_std``: trainable log standard deviations;
    * ``weight_prior`` / ``bias_prior``: non-trainable tensors (not used inside
      this class);
    * ``weight_noise`` / ``bias_noise``: non-trainable storage that
      :meth:`sample_noise` refills with standard-normal draws.

    These extra tensors are allocated uninitialized (``torch.empty_like``).
    NOTE(review): presumably the caller initializes them before the first
    :meth:`sample_noise` -- confirm. When ``bias=False`` every bias-related
    attribute (``bias_log_std``, ``bias_prior``, ``bias_noise``,
    ``realised_bias``) is ``None``.

    :meth:`sample_noise` must be called before :meth:`forward`. ``forward`` always
    uses the most recently sampled ``realised_weight``/``realised_bias``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        """Create the layer. Parameters are the same as PyTorch's Conv2d."""
        super(StochasticConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode
        )
        # Keep this registration order: it determines the order of
        # parameters() and state_dict() entries.
        self.register_parameter('weight_log_std', self._empty_like(self.weight, requires_grad=True))
        self.register_parameter('bias_log_std', self._empty_like(self.bias, requires_grad=True))
        self.register_parameter('weight_prior', self._empty_like(self.weight, requires_grad=False))
        self.register_parameter('bias_prior', self._empty_like(self.bias, requires_grad=False))
        self.register_parameter('weight_noise', self._empty_like(self.weight, requires_grad=False))
        self.register_parameter('bias_noise', self._empty_like(self.bias, requires_grad=False))

    @staticmethod
    def _empty_like(reference, requires_grad):
        """Return an uninitialized Parameter shaped like ``reference``, or None if it is None."""
        # ``self.bias`` is None when the layer is built with bias=False.
        if reference is None:
            return None
        return Parameter(torch.empty_like(reference), requires_grad=requires_grad)

    def sample_noise(self):
        """
        Sample weights (and bias, if any) from the posterior.

        Sets ``realised_weight = weight + eps * exp(weight_log_std)`` with
        ``eps ~ N(0, 1)`` drawn in place into ``weight_noise``. ``realised_bias``
        is built the same way, or set to None when the layer has no bias.

        :return: None
        """
        self.realised_weight = self.weight + self.weight_noise.normal_() * torch.exp(self.weight_log_std)
        if self.bias is None:
            self.realised_bias = None
        else:
            self.realised_bias = self.bias + self.bias_noise.normal_() * torch.exp(self.bias_log_std)

    def forward(self, input):
        """Convolve ``input`` with the most recently sampled weights and bias.

        Raises AttributeError if :meth:`sample_noise` has not been called yet,
        because ``realised_weight`` is only set there.
        """
        if self.padding_mode != 'zeros':
            # Non-zero padding modes ('reflect', 'replicate', 'circular') are
            # applied explicitly with F.pad, then the convolution runs unpadded.
            # Each spatial side gets the full ``padding`` amount (F.pad order:
            # left, right, top, bottom), which matches torch.nn.Conv2d.
            # Newer Conv2d versions precompute this as a private attribute
            # (which also covers padding='same'). Fall back to computing it from
            # the padding tuple when that attribute is absent.
            pad = getattr(self, '_reversed_padding_repeated_twice', None)
            if pad is None:
                pad = (self.padding[1], self.padding[1], self.padding[0], self.padding[0])
            return F.conv2d(F.pad(input, pad, mode=self.padding_mode),
                            self.realised_weight, self.realised_bias, self.stride,
                            0, self.dilation, self.groups)
        return F.conv2d(input, self.realised_weight, self.realised_bias, self.stride,
                        self.padding, self.dilation, self.groups)
class StochasticLinear(nn.Module):
    """Fully connected layer with Gaussian-perturbed weights and bias.

    Constructor arguments are the same as :class:`torch.nn.Linear`, except that a
    bias is always created. ``weight``/``bias`` hold the means, and
    ``weight_log_std``/``bias_log_std`` the trainable log standard deviations.
    ``weight_prior``/``bias_prior`` are non-trainable tensors of the same shapes
    (not used inside this class). ``weight_noise``/``bias_noise`` are
    non-trainable storage refilled with standard-normal draws by
    :meth:`sample_noise`.

    All tensors are allocated uninitialized here. NOTE(review): presumably they
    are initialized elsewhere before use -- confirm. :meth:`sample_noise` must
    run before :meth:`forward`.
    """

    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features):
        """Create the layer. Parameters are the same as PyTorch's Linear."""
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        w_shape = (out_features, in_features)
        b_shape = (out_features,)
        # Registration order below fixes the parameters()/state_dict() order.
        self.weight = Parameter(torch.Tensor(*w_shape))
        self.weight_log_std = Parameter(torch.Tensor(*w_shape))
        self.weight_prior = Parameter(torch.Tensor(*w_shape), requires_grad=False)
        self.bias = Parameter(torch.Tensor(*b_shape))
        self.bias_log_std = Parameter(torch.Tensor(*b_shape))
        self.bias_prior = Parameter(torch.Tensor(*b_shape), requires_grad=False)
        self.weight_noise = Parameter(torch.Tensor(*w_shape), requires_grad=False)
        self.bias_noise = Parameter(torch.Tensor(*b_shape), requires_grad=False)

    def sample_noise(self) -> None:
        """
        Draw a fresh realisation of the weights and bias.

        ``realised_* = mean + eps * exp(log_std)`` where ``eps ~ N(0, 1)`` is
        written in place into the corresponding ``*_noise`` tensor.

        :return: None
        """
        weight_eps = self.weight_noise.normal_()
        self.realised_weight = self.weight + weight_eps * torch.exp(self.weight_log_std)
        bias_eps = self.bias_noise.normal_()
        self.realised_bias = self.bias + bias_eps * torch.exp(self.bias_log_std)

    def forward(self, input) -> torch.FloatTensor:
        """Apply the affine map using the most recently sampled weights and bias."""
        return F.linear(input, self.realised_weight, self.realised_bias)

    def extra_repr(self) -> str:
        """Summary shown by ``repr(module)``."""
        has_bias = self.bias is not None
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={has_bias}'
class NearestEmbedding(nn.Module):
    """Codebook of ``embedding_num`` vectors of size ``embedding_dim``.

    Inputs are normalized with ``BatchNorm1d(embedding_dim)`` and then passed,
    together with the codebook ``weight``, to ``NearestEmbeddingFunction``
    (defined elsewhere).
    """

    def __init__(self, embedding_num, embedding_dim):
        super(NearestEmbedding, self).__init__()
        self.weight = Parameter(torch.Tensor(embedding_num, embedding_dim))
        self.reset_parameters()
        self.bn = nn.BatchNorm1d(embedding_dim)
        self.embedding_num = embedding_num

    def reset_parameters(self):
        """Redraw the codebook in place from a standard normal distribution."""
        # ``weight`` is a leaf Parameter that requires grad. An in-place op on
        # it outside no_grad() raises "a leaf Variable that requires grad is
        # being used in an in-place operation", which would make __init__ fail.
        with torch.no_grad():
            self.weight.normal_(0, 1)

    def forward(self, input):
        """Return the first output of ``NearestEmbeddingFunction`` for the normalized input.

        NOTE(review): presumably the nearest codebook vectors -- confirm against
        NearestEmbeddingFunction.
        """
        input = self.bn(input)
        return NearestEmbeddingFunction.apply(input, self.weight)[0]

    def forward_onehot(self, input):
        """Return a ``(input.size(0), embedding_num)`` one-hot matrix of selected indices.

        The second output of ``NearestEmbeddingFunction`` is scattered as 1s into
        a zero matrix with the same dtype/device as ``input``. Assumes it holds
        one integer index per row of ``input`` -- TODO confirm.
        """
        input = self.bn(input)
        indices = NearestEmbeddingFunction.apply(input, self.weight)[1]
        onehot = input.new_zeros(input.size(0), self.embedding_num)
        onehot.scatter_(1, indices.view(-1, 1), 1)
        return onehot
class Convolution(nn.Module):
    r"""Performs a 2D convolution over an input spike-wave composed of several input planes.

    The current version only supports a stride of 1 with no padding.

    The input is a 4D tensor with the size :math:`(T, C_{{in}}, H_{{in}}, W_{{in}})` and the
    corresponding output is of size :math:`(T, C_{{out}}, H_{{out}}, W_{{out}})`,
    where :math:`T` is the number of time steps, :math:`C` is the number of feature maps
    (channels), and :math:`H` and :math:`W` are the height and width of the input/output planes.

    * :attr:`in_channels` controls the number of input planes (channels/feature maps).
    * :attr:`out_channels` controls the number of feature maps in the current layer.
    * :attr:`kernel_size` controls the size of the convolution kernel. It can be a single
      integer or a tuple of two integers.
    * :attr:`weight_mean` controls the mean of the normal distribution used for the
      initial random weights.
    * :attr:`weight_std` controls the standard deviation of the normal distribution used
      for the initial random weights.

    .. note::

        Since this version of convolution does not support padding, it is the user's
        responsibility to add proper padding to the input before applying the convolution.

    The weights are registered with ``requires_grad=False``: gradients are not used for them.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        weight_mean (float, optional): Mean of the initial random weights. Default: 0.8
        weight_std (float, optional): Standard deviation of the initial random weights. Default: 0.02
    """

    def __init__(self, in_channels, out_channels, kernel_size, weight_mean=0.8, weight_std=0.02):
        super(Convolution, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = to_pair(kernel_size)

        # Fixed arguments of the underlying conv2d call, kept as attributes for
        # future use (stride 1, no bias, no dilation, one group, no padding).
        self.stride = 1
        self.bias = None
        self.dilation = 1
        self.groups = 1
        self.padding = 0

        # Gradients are not used for these weights, so autograd is disabled.
        weight_shape = (self.out_channels, self.in_channels, *self.kernel_size)
        self.weight = Parameter(torch.Tensor(*weight_shape), requires_grad=False)
        self.reset_weight(weight_mean, weight_std)

    def reset_weight(self, weight_mean=0.8, weight_std=0.02):
        """Redraw the weights in place from a normal distribution.

        Args:
            weight_mean (float, optional): Mean of the random weights. Default: 0.8
            weight_std (float, optional): Standard deviation of the random weights. Default: 0.02
        """
        self.weight.normal_(weight_mean, weight_std)

    def load_weight(self, target):
        """Copy ``target`` into the weights in place.

        Args:
            target (Tensor): The tensor to copy from (must be broadcastable to the weight shape).
        """
        self.weight.copy_(target)

    def forward(self, input):
        """Convolve ``input`` with the current weights (no bias, stride 1, no padding)."""
        return fn.conv2d(
            input,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )