class twnLinear(nn.Linear):
    """Ternary-weight (TWN) fully connected layer for quantization.

    The full-precision ``weight`` is the trainable master copy.  Every forward
    pass projects it onto ``{-1, 0, +1}`` (stored in ``weight_ternary``) with
    the threshold ``delta = cRate * mean(|W|)``.  A single scale ``alpha``,
    the mean magnitude of the weights that survive the threshold, multiplies
    the input.

    Gradients accumulate on ``weight_ternary``; call :meth:`compute_grad` after
    ``backward()`` to copy them onto ``weight``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, add a learnable bias (as in ``nn.Linear``).
        cRate: threshold ratio; weights with ``|w| <= cRate * mean(|W|)`` are zeroed.
    """

    def __init__(self, in_features, out_features, bias=True, cRate=0.7):
        super(twnLinear, self).__init__(in_features=in_features,
                                        out_features=out_features,
                                        bias=bias)
        self.weight_ternary = Parameter(torch.zeros(self.weight.data.size()))
        self.weight_alpha = Parameter(torch.ones(1))
        self.weight_delta = 0
        self.cRate = cRate

    def compute_grad(self):
        """Copy the gradient of the ternary weights onto the full-precision weights."""
        self.weight.grad = self.weight_ternary.grad

    def forward(self, input):
        # ``.item()`` rather than ``.data[0]``: the mean is a 0-dim tensor, and
        # indexing a 0-dim tensor raises IndexError on modern PyTorch.
        self.weight_delta = self.cRate * \
            self.weight.abs().mean().clamp(min=0, max=10).item()
        self.weight_ternary.data.copy_(
            (self.weight.gt(self.weight_delta).float() -
             self.weight.lt(-self.weight_delta).float()).data)
        # Number of non-zero ternary entries, clamped to 1 so that an all-zero
        # projection gives alpha = 0 instead of NaN (0 / 0).
        n_nonzero = self.weight_ternary.abs().sum().clamp(min=1)
        self.weight_alpha.data.copy_(
            ((self.weight.abs() * self.weight_ternary.abs()).sum() /
             n_nonzero).clamp(min=0, max=10).data)
        return F.linear(input * self.weight_alpha.data[0], self.weight_ternary,
                        self.bias)
class XNORConv2d(Module):
    """Binary-weight (XNOR-Net style) 2-D convolution.

    A full-precision copy of the kernel lives in ``fp_weights``.  Each forward
    pass re-centres it across the input-channel axis and clamps it to
    ``[-1, 1]``.  It then writes the binarised kernel
    ``sign(W) * mean(|W|)`` (one scale per output channel) into the wrapped
    ``Conv2d`` before running it.  After ``backward()``, call
    :meth:`update_gradient` to map the gradient of the binarised kernel back
    onto ``fp_weights``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups:
            forwarded to ``Conv2d``.
        bias: whether the wrapped ``Conv2d`` has a bias (zero-initialised).
        dropout_ratio: currently unused.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, groups=1, bias=True, dropout_ratio=0):
        super(XNORConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        # ``bias`` is now honoured; previously the conv always had a bias.
        self.conv = Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=kernel_size, stride=stride,
                           padding=padding, groups=groups, bias=bias)
        self.conv.weight.data.normal_(0, 0.05)
        if self.conv.bias is not None:
            self.conv.bias.data.zero_()
        self.fp_weights = Parameter(zeros(self.conv.weight.size()))
        self.fp_weights.data.copy_(self.conv.weight.data)

    def forward(self, x):
        w = self.fp_weights.data
        # Re-centre each filter across the input-channel axis, then clamp.
        w = w - w.mean(1, keepdim=True)
        w.clamp_(-1, 1)
        self.fp_weights.data = w
        # Per-output-channel scale, computed on plain data: it is only needed
        # as a value, and tracking it would keep the graph alive between steps.
        self.mean_val = w.abs().view(self.out_channels, -1).mean(1, keepdim=True)
        self.conv.weight.data.copy_(w.sign() * self.mean_val.view(-1, 1, 1, 1))
        return self.conv(x)

    def update_gradient(self):
        """Set ``fp_weights.grad`` from the gradient of the binarised kernel.

        Must be called after ``forward`` (it reads ``self.mean_val``) and
        after ``backward`` (it reads ``self.conv.weight.grad``).  The result is
        a plain tensor without autograd history.
        """
        w = self.fp_weights.data
        grad = self.conv.weight.grad
        # Straight-through estimator: pass the gradient only where |w| <= 1.
        proxy = w.abs().sign()
        proxy[w.abs() > 1] = 0
        binary_grad = grad * self.mean_val.view(-1, 1, 1, 1) * proxy
        # Contribution through the per-channel scale mean(|W|).
        sign = self.conv.weight.data.sign()
        mean_grad = (sign * grad).view(self.out_channels, -1).mean(1).view(-1, 1, 1, 1)
        mean_grad = mean_grad * sign
        # Rescale by filter size and the centring factor (1 - 1/C_in).
        self.fp_weights.grad = (binary_grad + mean_grad) * w[0].nelement() * \
            (1 - 1 / w.size(1))
class FactorizedSpatialTransformerPyramid2d(SpatialTransformerPyramid2d):
    """Spatial-transformer pyramid readout with factorised feature weights.

    The readout weights are the product of a per-scale factor
    (``feature_scales``, shape ``(1, scale_n + 1, 1, outdims)``) and a
    per-channel factor (``feature_channels``, shape ``(1, 1, c, outdims)``).
    They are exposed flattened through the :attr:`features` property.

    Args:
        in_shape: ``(c, w, h)``; only ``c`` is used here.
        outdims: number of output units.
        scale_n: number of pyramid levels (``scale_n + 1`` scales are read out).
        positive: stored on the instance.
        bias: if True, register a per-unit ``bias`` parameter.
        init_range: grid locations are initialised from ``U(-init_range, init_range)``.
        downsample, type: forwarded to ``Pyramid``.
    """

    def __init__(self, in_shape, outdims, scale_n=4, positive=False, bias=True,
                 init_range=.1, downsample=True, type=None):
        # NOTE(review): SpatialTransformerPyramid2d.__init__ is skipped on
        # purpose (presumably because ``features`` is a property here) - confirm.
        super(SpatialTransformerPyramid2d, self).__init__()
        self.in_shape = in_shape
        channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss_pyramid = Pyramid(scale_n=scale_n, downsample=downsample,
                                     type=type)
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.feature_scales = Parameter(torch.Tensor(1, scale_n + 1, 1, outdims))
        self.feature_channels = Parameter(torch.Tensor(1, 1, channels, outdims))
        self.register_parameter(
            'bias', Parameter(torch.Tensor(outdims)) if bias else None)
        self.init_range = init_range
        self.initialize()

    @property
    def features(self):
        """Scale x channel weights flattened to ``(1, (scale_n + 1) * c, 1, outdims)``."""
        product = self.feature_scales * self.feature_channels
        return product.view(1, -1, 1, self.outdims)

    def scale_l1(self, average=True):
        """L1 norm of the per-scale factor (mean if ``average`` else sum)."""
        magnitudes = self.feature_scales.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def channel_l1(self, average=True):
        """L1 norm of the per-channel factor (mean if ``average`` else sum)."""
        magnitudes = self.feature_channels.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def initialize(self):
        """Uniform grid init; both factors set to 1/sqrt(c), so their product is 1/c."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        fill_value = 1 / np.sqrt(self.in_shape[0])
        self.feature_scales.data.fill_(fill_value)
        self.feature_channels.data.fill_(fill_value)
        if self.bias is not None:
            self.bias.data.fill_(0)
class NormalVar(Sampler):
    """Additive zero-mean Gaussian noise with a learnable per-channel precision.

    In training mode :meth:`forward` adds noise of precision ``|prec|`` to
    the detached input.  It also returns a per-sample log-probability
    difference.  In eval mode the input is returned unchanged together with 0.

    Args:
        input_channel: number of channels; ``prec`` has shape ``(1, C, 1, 1)``.
        init_logvar: initial value of every entry of ``prec``.
            NOTE(review): despite its name, this is used directly as a
            precision, not a log-variance.
        *args, **kwargs: forwarded to ``Sampler``.
    """

    def __init__(self, input_channel, *args, init_logvar=1, **kwargs):
        super(NormalVar, self).__init__(*args, **kwargs)
        self.prec = Parameter(torch.ones(1, input_channel, 1, 1) * init_logvar)
        # The same Parameter is also registered as 'logvar', so it appears
        # under both names in state_dict (presumably for checkpoint
        # compatibility - confirm).
        self.register_parameter('logvar', self.prec)

    def sample_normal(self, inputs, prec):
        """Draw detached zero-mean noise shaped like ``inputs`` with std ``1/sqrt(|prec|)``."""
        std = prec.abs().sqrt()
        return (torch.randn_like(inputs) / std).detach()

    def sample_normal_nat(self, loc, scale):
        """Natural-parameter sampler stub.

        NOTE(review): this computes a mean and a noise draw but always returns
        None.  The draw still advances the RNG.
        """
        mean = loc / scale
        var = 1 / scale
        output = (torch.randn_like(loc) * var.abs().sqrt()).detach()
        return None

    def log_prob_normal(self, state, prec):
        """Zero-mean Gaussian log-density up to the constant: ``log|prec|/2 - state^2 |prec| / 2``."""
        abs_prec = prec.abs()
        return abs_prec.log() / 2 - state ** 2 * abs_prec / 2

    def forward(self, inputs, concentration=1):
        """Return ``(noisy_inputs, log_ratio)`` in training mode, ``(inputs, 0)`` otherwise."""
        inputs = inputs.detach()
        if not self.training:
            return inputs, 0
        prec = self.prec.abs()
        # Draw order (base noise first, then concentrated noise) is kept fixed.
        noise = self.sample_normal(inputs, prec)
        noise_conc = self.sample_normal(inputs, prec * concentration)
        log_ratio = (self.log_prob_normal(noise, prec) -
                     self.log_prob_normal(noise_conc, prec))
        log_ratio = log_ratio.sum(dim=(1, 2, 3), keepdim=True).squeeze()
        return inputs + noise, log_ratio
class twnConv2d(nn.Conv2d):
    """Ternary-weight (TWN) 2-D convolution for quantization.

    The full-precision ``weight`` is the trainable master copy.  Every forward
    pass projects it onto ``{-1, 0, +1}`` (stored in ``weight_ternary``) with
    the threshold ``delta = cRate * mean(|W|)``.  A single scale ``alpha``,
    the mean magnitude of the weights that survive the threshold, multiplies
    the input.

    Gradients accumulate on ``weight_ternary``; call :meth:`compute_grad` after
    ``backward()`` to copy them onto ``weight``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded to ``nn.Conv2d``.
        cRate: threshold ratio; weights with ``|w| <= cRate * mean(|W|)`` are zeroed.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, cRate=0.7):
        super(twnConv2d, self).__init__(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride,
                                        padding=padding, bias=bias)
        self.weight_ternary = Parameter(torch.zeros(self.weight.data.size()))
        self.weight_alpha = Parameter(torch.ones(1))
        self.weight_delta = 0
        self.cRate = cRate

    def compute_grad(self):
        """Copy the gradient of the ternary weights onto the full-precision weights."""
        self.weight.grad = self.weight_ternary.grad

    def forward(self, input):
        # ``.item()`` rather than ``.data[0]``: the mean is a 0-dim tensor, and
        # indexing a 0-dim tensor raises IndexError on modern PyTorch.
        self.weight_delta = self.cRate * \
            self.weight.abs().mean().clamp(min=0, max=10).item()
        self.weight_ternary.data.copy_(
            (self.weight.gt(self.weight_delta).float() -
             self.weight.lt(-self.weight_delta).float()).data)
        # Number of non-zero ternary entries, clamped to 1 so that an all-zero
        # projection gives alpha = 0 instead of NaN (0 / 0).
        n_nonzero = self.weight_ternary.abs().sum().clamp(min=1)
        self.weight_alpha.data.copy_(
            ((self.weight.abs() * self.weight_ternary.abs()).sum() /
             n_nonzero).clamp(min=0, max=10).data)
        return F.conv2d(input * self.weight_alpha.data[0], self.weight_ternary,
                        self.bias, self.stride, self.padding, self.dilation,
                        self.groups)
class ActNorm(nn.Module):
    """Per-channel affine activation normalisation, ``y = weight * x + bias``.

    Two ways to initialise:
      - data init: the first minibatch seen by ``forward`` sets the parameters
        so that each channel of that batch has zero mean and unit std;
      - identity transform (weight = 1, bias = 0): used in sampling-based
        training of cGlow.

    Args:
        in_features (int): number of channels; parameters have shape ``(C, 1, 1)``.
        return_logdet (bool): if True (default), ``forward``/``reverse`` also
            return ``sum(log|weight|) * H * W``.
        data_init (bool): use one-minibatch data initialisation (default False).
    """

    def __init__(self, in_features, return_logdet=True, data_init=False):
        super(ActNorm, self).__init__()
        # Start from the identity transform.
        self.weight = Parameter(torch.ones(in_features, 1, 1))
        self.bias = Parameter(torch.zeros(in_features, 1, 1))
        self.data_init = data_init
        self.data_initialized = False
        self.return_logdet = return_logdet

    def _init_parameters(self, input):
        """Set weight/bias from per-channel statistics of a (B, C, H, W) batch."""
        n_channels = input.shape[1]
        # (B, C, H, W) -> (C, B, H, W) -> (C, B*H*W)
        flat = input.transpose(0, 1).contiguous().view(n_channels, -1)
        mean = flat.mean(1)
        std = flat.std(1) + 1e-6
        self.bias.data = -(mean / std)[:, None, None]
        self.weight.data = (1. / std)[:, None, None]

    def _logdet(self, t):
        """Log-determinant of the affine map over the last two (spatial) dims of ``t``."""
        return self.weight.abs().log().sum() * t.shape[-1] * t.shape[-2]

    def forward(self, x):
        if self.data_init and not self.data_initialized:
            self._init_parameters(x)
            self.data_initialized = True
        out = self.weight * x + self.bias
        if not self.return_logdet:
            return out
        return out, self._logdet(x)

    def reverse(self, y):
        # NOTE(review): the returned logdet has the same sign as in forward;
        # callers must negate it if they need the inverse's log-determinant.
        out = (y - self.bias) / self.weight
        if not self.return_logdet:
            return out
        return out, self._logdet(y)
def test_quant_num_grad_align_zero():
    """Cross-check three implementations of zero-aligned fake linear quantization.

    On a random (1, 3, 224, 224) input with k = 8 bits, this compares:
      1. ``fake_linear_quant`` (autograd) - the reference;
      2. ``cuda_fake_linear_quant`` - forward output and gradients w.r.t.
         ``x``, ``lb`` and ``ub``;
      3. a hand-written numerical gradient built with ``d_lb_ub`` and ``d_x``
         - gradients only.
    """
    # TODO: we should add gradients to `clamp` op here
    x = torch.randn(1, 3, 224, 224, requires_grad=True, dtype=DTYPE, device=DEVICE)
    d_qx = torch.randn_like(x).detach()  # upstream gradient for backward()
    # Boundaries just inside the data range, so at least the extremes are clipped.
    lb = Parameter(x.detach().min() + 0.1)
    ub = Parameter(x.detach().max() - 0.1)
    k = 8  # bit width

    # autograd implementation (reference)
    assert ub.detach() - lb.detach() > 1e-2
    qx = fake_linear_quant(x, lb, ub, k, align_zero=True)
    qx.backward(d_qx)
    qx_gt = qx.detach()
    # ``.clone()`` is required: ``.detach()`` shares storage with ``.grad``.
    # The ``zero_()`` calls below (and the next backward, which may accumulate
    # in place) would otherwise overwrite these reference values, making the
    # CUDA-vs-autograd gradient comparisons vacuous.
    d_lb_gt = lb.grad.detach().clone()
    d_ub_gt = ub.grad.detach().clone()
    d_x_gt = x.grad.detach().clone()

    # CUDA numerical implementation
    lb.grad.data.zero_()
    ub.grad.data.zero_()
    x.grad.data.zero_()
    qx = cuda_fake_linear_quant(x, lb, ub, k, align_zero=True)
    qx.backward(d_qx)
    qx_cuda = qx.detach()
    d_lb_cuda = lb.grad.detach().clone()
    d_ub_cuda = ub.grad.detach().clone()
    d_x_cuda = x.grad.detach().clone()
    assert torch.allclose(qx_cuda, qx_gt)
    assert torch.allclose(d_lb_cuda, d_lb_gt)
    assert torch.allclose(d_ub_cuda, d_ub_gt)
    assert torch.allclose(d_x_cuda, d_x_gt)

    # numerical grad implementation
    with torch.no_grad():
        N = torch.tensor(2 ** k - 1, dtype=DTYPE, device=DEVICE)  # number of steps
        delta = ub.sub(lb).div(N)  # step size
        z = torch.round(lb.abs().div(delta))  # integer zero point
        # Range snapped so that 0 lies exactly on the grid.
        lb_ = z.neg().mul(delta)
        ub_ = (N - z).mul(delta)
        x_mask = (lb_ <= x) & (x <= ub_)  # pre-compute mask
        x = torch.clamp(x, lb_.item(), ub_.item())
        i = torch.round(x.sub(lb_).div(delta))
        # after forward, calculate cache
        x_sub = x - lb_ - torch.abs(lb)
        d_i = (i - z) - (x_sub / delta)
        d_lb, d_ub = d_lb_ub(d_qx, d_i, N, torch.sign(lb))
        dx = d_x(d_qx, x_mask)
    assert torch.allclose(d_lb_gt, d_lb)
    assert torch.allclose(d_ub_gt, d_ub)
    assert torch.allclose(dx, d_x_gt)
class XNORLinear(Module):
    """Binary-weight (XNOR-Net style) fully connected layer.

    A full-precision copy of the weight matrix lives in ``fp_weights``.  Each
    forward pass re-centres each row, clamps it to ``[-1, 1]``, and writes
    ``sign(W) * mean(|W|)`` (one scale per output unit) into the wrapped
    ``Linear`` before running it.  After ``backward()``, call
    :meth:`update_gradient` to map the gradient of the binarised weight back
    onto ``fp_weights``.

    Args:
        in_features, out_features, bias: forwarded to ``Linear``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(XNORLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias  # stored as the bool flag passed in
        self.linear = Linear(in_features=in_features, out_features=out_features,
                             bias=bias)
        self.fp_weights = Parameter(zeros(self.linear.weight.size()))
        self.fp_weights.data.copy_(self.linear.weight.data)

    def forward(self, x):
        w = self.fp_weights.data
        # Re-centre each row across the input axis, then clamp.
        w = w - w.mean(1, keepdim=True)
        w.clamp_(-1, 1)
        self.fp_weights.data = w
        # Per-output scale, computed on plain data: it is only needed as a
        # value, and tracking it would keep the graph alive between steps.
        self.mean_val = w.abs().view(self.out_features, -1).mean(1, keepdim=True)
        self.linear.weight.data.copy_(w.sign() * self.mean_val.view(-1, 1))
        return self.linear(x)

    def update_gradient(self):
        """Set ``fp_weights.grad`` from the gradient of the binarised weight.

        Must be called after ``forward`` (it reads ``self.mean_val``) and
        after ``backward`` (it reads ``self.linear.weight.grad``).  The result
        is a plain tensor without autograd history.
        """
        w = self.fp_weights.data
        grad = self.linear.weight.grad
        # Straight-through estimator: pass the gradient only where |w| <= 1.
        proxy = w.abs().sign()
        proxy[w.abs() > 1] = 0
        binary_grad = grad * self.mean_val.view(-1, 1) * proxy
        # Contribution through the per-row scale mean(|W|).
        sign = self.linear.weight.data.sign()
        mean_grad = (sign * grad).view(self.out_features, -1).mean(1).view(-1, 1)
        mean_grad = mean_grad * sign
        # Rescale by row length and the centring factor (1 - 1/in_features).
        self.fp_weights.grad = (binary_grad + mean_grad) * w[0].nelement() * \
            (1 - 1 / w.size(1))
        return
class GraphSIR(torch.nn.Module):
    """A SIR model on a graph that also models travelling of the infected population.

    The state ``s`` passed to :meth:`forward` has one row per city.  Its
    columns are used as ``s[:, 0]`` (susceptible), ``s[:, 1]`` (infected) and
    ``s[:, 2]`` (recovered).
    """

    def __init__(self, intra_b, intra_k, inter_adj, inter_b, device='cpu'):
        """Definition of the coefficients follows the SIR model.

        Args:
            intra_b (Tensor): intra-city transmission probability; each city can
                have a different value (e.g. depending on how crowded it is).
            intra_k (Tensor): intra-city recovery probability; each city can
                have a different value.
            inter_adj (Tensor): integer tensor of shape (# of edges, 2).  Each
                undirected edge is listed once and travel is applied in both
                directions, which keeps the exchange balanced.
            inter_b (Tensor): travelling probability of the infected along the
                edges (presumably one value per edge, or broadcastable - confirm).
            device (str, optional): device to run the model on; "cpu" by
                default (a CUDA device requires a CUDA-enabled GPU).
        """
        super().__init__()
        self.N = intra_k.shape[0]  # number of nodes (cities)
        # b: infection probability within each city
        self.intra_b = Parameter(intra_b.to(device))
        # k: healing probability within each city
        self.intra_k = Parameter(intra_k.to(device))
        # edge list among all the cities
        self.inter_adj = inter_adj.to(device)
        # Moved to ``device`` like the other coefficients; otherwise a CUDA
        # model fails with a device mismatch in ``forward``.
        self.inter_b = Parameter(inter_b.to(device))
        self.device = device

    def _scatter_sum(self, values, index):
        """Sum ``values`` into a length-``N`` vector at positions ``index`` (scatter-add)."""
        return values.new_zeros(self.N).index_add(0, index, values)

    def forward(self, t, s):
        """Return ds/dt with the same (N, 3) layout as ``s``; ``t`` is unused.

        The result takes the dtype obtained by combining ``s`` with the model
        parameters, so a float64 state yields a float64 derivative.
        """
        src = self.inter_adj[:, 0]
        dst = self.inter_adj[:, 1]
        travel = self.inter_b.abs()
        infected = s[:, 1]

        # Inter-city exchange of infected: flows src -> dst and dst -> src
        # along every edge; what leaves one city arrives at the other.
        i_2_j = travel * infected[src]
        j_2_i = travel * infected[dst]
        di_inter = (self._scatter_sum(i_2_j, dst) - self._scatter_sum(i_2_j, src)) + \
            (self._scatter_sum(j_2_i, src) - self._scatter_sum(j_2_i, dst))

        # Intra-city SIR dynamics.
        infection = s[:, 0] * s[:, 1] * self.intra_b.abs()
        recovery = s[:, 1] * self.intra_k.abs()

        return torch.stack(
            [-infection, di_inter + (infection - recovery), recovery], dim=1)
class DiffBoundary:
    """Mixin that adds learnable quantization boundaries ``lb``/``ub`` to a layer.

    The host class must provide ``self.weight``.  It must also have run
    ``nn.Module.__init__`` before ``DiffBoundary.__init__``, because ``lb`` and
    ``ub`` are assigned as ``Parameter`` attributes.

    Args:
        bit_width: number of quantization bits; the grid has ``2**bit_width`` levels.
    """

    def __init__(self, bit_width=4):
        # TODO: add channel-wise option?
        self.bit_width = bit_width
        self.register_boundaries()

    def register_boundaries(self):
        """Create scalar ``lb``/``ub`` Parameters from the current weight range."""
        assert hasattr(self, "weight")
        self.lb = Parameter(self.weight.data.min())
        self.ub = Parameter(self.weight.data.max())

    def reset_boundaries(self):
        """Reset ``lb``/``ub`` to the current min/max of the weight."""
        assert hasattr(self, "weight")
        self.lb.data = self.weight.data.min()
        self.ub.data = self.weight.data.max()

    def get_quant_weight(self, align_zero=True):
        """Return the fake-quantized weight.

        Args:
            align_zero: if True, use a grid on which 0 is exactly representable.
        """
        # TODO: set `align_zero`?
        if align_zero:
            return self._get_quant_weight_align_zero()
        return self._get_quant_weight()

    def _get_quant_weight(self):
        """Quantize the weight uniformly on ``[lb, ub]`` with ``2**bit_width - 1`` steps.

        ``self.weight`` is detached, so through this path gradients reach only
        ``lb`` and ``ub``.
        """
        round_ = RoundSTE.apply
        w = self.weight.detach()
        delta = (self.ub - self.lb) / (2**self.bit_width - 1)
        w = torch.clamp(w, self.lb.item(), self.ub.item())
        idx = round_((w - self.lb).div(delta))  # TODO: do we need STE here?
        qw = (idx * delta) + self.lb
        return qw

    def _get_quant_weight_align_zero(self):
        """Quantize the weight on a grid shifted so that 0 is a grid point.

        The learned range is snapped to ``[-z * delta, (n - z) * delta]`` with
        ``z = round(|lb| / delta)``.  Weights are clamped to it and rounded to
        the nearest grid point.
        NOTE(review): this assumes ``lb <= 0``; for a positive ``lb`` the
        snapped range lies on the wrong side of zero.
        """
        round_ = RoundSTE.apply
        n = 2**self.bit_width - 1
        w = self.weight.detach()
        delta = (self.ub - self.lb) / n
        z = round_(self.lb.abs() / delta)
        lb = -z * delta
        ub = (n - z) * delta
        w = torch.clamp(w, lb.item(), ub.item())
        # The index must be taken relative to the *aligned* lower bound ``lb``,
        # not ``self.lb``.  Otherwise the rounding is offset by the fractional
        # part of |self.lb| / delta (e.g. w = 0.3 * delta could map to delta).
        idx = round_((w - lb).div(delta))  # TODO: do we need STE here?
        qw = (idx - z) * delta
        return qw
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float]]] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param norm_by_max: Normalize the weight of a neuron by its max weight.
        :param norm_by_max_from_shadow_weights: Normalize the weight of a neuron by its
            max weight, computed from shadow (un-normalized) weights.
            NOTE(review): previously documented as ``norm_by_max_with_shadow_weights``;
            this class reads ``self.norm_by_max_from_shadow_weights`` - confirm the
            keyword accepted by ``AbstractConnection``.
        """
        super().__init__(source, target, nu, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        if self.norm_by_max_from_shadow_weights:
            # Shadow copy accumulating weight changes without normalization.
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or without
            decaying spike activation).
        """
        # Compute multiplication of spike activations by connection weights and add bias.
        post = s.float().view(-1) @ self.w + self.b
        return post.view(*self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0  # avoid division by zero
            self.w *= self.norm / w_abs_sum

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron.
        """
        if self.norm_by_max:
            w_max = self.w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0  # avoid division by zero
            self.w /= w_max

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max weight of the target neuron, using shadow weights.

        The change of ``w`` since the last call is added to ``shadow_w``; ``w`` is
        then set to ``shadow_w`` divided by the per-target max of ``|shadow_w|``.
        """
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w += self.w - self.prev_w
            w_max = self.shadow_w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0  # avoid division by zero
            # In-place update: ``w`` is a registered Parameter, and assigning a
            # plain tensor to it raises TypeError in nn.Module.__setattr__.
            self.w.copy_(self.shadow_w / w_max)
            self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_()
class Connection(AbstractConnection):
    # full connection
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        # Create w: a (source.n, target.n) matrix shaped by the source and target layers.
        w = kwargs.get("w", None)
        if w is None:  # no initial w given
            if self.wmin == -np.inf or self.wmax == np.inf:
                # At least one bound is unset: random values in [0, 1), clamped.
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                # Both bounds set: uniform in [wmin, wmax).
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)

        b = kwargs.get("b", None)
        if b is not None:
            self.b = Parameter(b, requires_grad=False)
        else:
            self.b = None

        if isinstance(self.target, CSRMNodes):
            # Windowed spike history, allocated lazily by compute_window().
            self.s_w = None

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        Input: incoming spikes from the source layer; output: the weighted input
        to the target layer.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights (with or without
            decaying spike activation).
        """
        # Compute multiplication of spike activations by weights and add bias.
        if self.b is None:
            post = s.view(s.size(0), -1).float() @ self.w  # (batch, source.n) @ (source.n, target.n)
        else:
            post = s.view(s.size(0), -1).float() @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def compute_window(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations for a sliding window of the most recent spikes.

        Keeps the last ``target.res_window_size`` spike vectors in ``self.s_w``
        (first in, first out) and multiplies each by the weights.  Only
        initialised for ``CSRMNodes`` targets.
        """
        if self.s_w is None:
            # Matrix of shape batch size x window size x source shape, allocated
            # on the same device as the incoming spikes so that cat() below
            # also works on GPU.
            self.s_w = torch.zeros(
                self.target.batch_size,
                self.target.res_window_size,
                *self.source.shape,
                device=s.device
            )

        # Push the new spike vector into the FIFO of windowed spike trains.
        self.s_w = torch.cat((self.s_w[:, 1:, :], s[:, None, :]), 1)

        # Compute multiplication of spike activations by weights and add bias.
        if self.b is None:
            post = self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float() @ self.w
        else:
            post = (self.s_w.view(self.s_w.size(0), self.s_w.size(1), -1).float()
                    @ self.w + self.b)

        return post.view(self.s_w.size(0), self.target.res_window_size,
                         *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal to
        ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0  # avoid division by zero
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
class PointPooled2d(Readout):
    def __init__(
        self,
        in_shape,
        outdims,
        pool_steps,
        bias,
        pool_kern,
        init_range,
        align_corners=True,
        mean_activity=None,
        feature_reg_weight=1.0,
        gamma_readout=None,  # deprecated, use feature_reg_weight instead
        **kwargs,
    ):
        """
        This readout learns a point in the core feature space for each neuron, with the help of
        torch.grid_sample, that best predicts its response. Multiple average pooling steps are
        applied to reduce the search space in each stage and thereby converge faster to the
        best prediction point.

        The readout receives the shape of the core as 'in_shape', the number of pooling stages
        to be performed as 'pool_steps', the kernel size and stride length used for pooling as
        'pool_kern', the number of units/neurons being predicted as 'outdims', 'bias' specifying
        whether or not a bias term is used, and 'init_range', the range for initialising the
        grid with a uniform distribution, U(-init_range, init_range).
        The grid parameter contains the normalized locations (x, y coordinates in the core
        feature space) and is clipped to [-1, 1] in every forward pass, as required by
        torch.grid_sample.
        The feature parameter learns the best linear mapping from the pooled feature maps at a
        given location to a unit's response.

        Args:
            in_shape (list): shape of the input feature map [channels, width, height]
            outdims (int): number of output units
            pool_steps (int): number of pooling stages
            bias (bool): adds a bias term
            pool_kern (int): filter size and stride length used for pooling the feature map
            init_range (float): initialises the grid with Uniform([-init_range, init_range])
                [expected: positive value <= 1]
            align_corners (bool): keyword argument to grid_sample for bilinear interpolation.
                It changed behavior in PyTorch 1.3. The default of align_corners = True
                selects the pre-PyTorch-1.3 behavior for compatibility.
            mean_activity: passed to initialize_bias when the bias is initialised.
            feature_reg_weight (float): multiplier of the L1 feature penalty in regularizer().
            gamma_readout: deprecated alias of feature_reg_weight.
            kwargs: accepted and ignored.

        Raises:
            ValueError: if init_range is not in (0, 1].
        """
        super().__init__()
        if init_range > 1.0 or init_range <= 0.0:
            raise ValueError("init_range is not within required limit!")
        self._pool_steps = pool_steps
        self.in_shape = in_shape
        c, w, h = in_shape
        self.outdims = outdims
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity
        # x-y coordinates for each neuron
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        # weight matrix mapping the core features to the output units
        self.features = Parameter(
            torch.Tensor(1, c * (self._pool_steps + 1), 1, outdims))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)

        self.pool_kern = pool_kern
        # kernel of size [pool_kern, pool_kern] with stride pool_kern
        self.avg = nn.AvgPool2d((pool_kern, pool_kern), stride=pool_kern,
                                count_include_pad=False)
        self.init_range = init_range
        self.align_corners = align_corners
        self.initialize(mean_activity)

    @property
    def pool_steps(self):
        return self._pool_steps

    @pool_steps.setter
    def pool_steps(self, value):
        """Change the number of pooling stages, re-creating and re-initialising ``features``.

        The new ``features`` is a new Parameter object, so optimizers created
        earlier do not track it.
        """
        assert value >= 0 and int(value) - value == 0, \
            "new pool steps must be a non-negative integer"
        if value != self._pool_steps:
            logger.info("Resizing readout features")
            c, w, h = self.in_shape
            self._pool_steps = int(value)
            # Allocate on the device/dtype of the current features, so that
            # resizing a readout already moved (e.g. to GPU) keeps working.
            self.features = Parameter(self.features.data.new_empty(
                1, c * (self._pool_steps + 1), 1, self.outdims))
            self.features.data.fill_(1 / self.in_shape[0])

    def initialize(self, mean_activity=None):
        """
        Initialise the grid (uniformly), the features (to 1 / channels) and the bias.
        """
        if mean_activity is None:
            mean_activity = self.mean_activity
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)

    def feature_l1(self, reduction="sum", average=None):
        """
        Returns the L1 regularization term for the features.

        Args:
            average (bool): Deprecated (see reduction); if True, use the mean of the
                weights for regularization.
            reduction (str): Specifies the reduction to apply to the output:
                'none' | 'mean' | 'sum'
        """
        return self.apply_reduction(self.features.abs(), reduction=reduction,
                                    average=average)

    def regularizer(self, reduction="sum", average=None):
        """L1 feature penalty scaled by ``feature_reg_weight``."""
        return self.feature_l1(reduction=reduction,
                               average=average) * self.feature_reg_weight

    def forward(self, x, shift=None, out_idx=None, **kwargs):
        """
        Propagates the input forwards through the readout.

        Args:
            x: input data of shape (N, c, w, h) matching ``in_shape``
            shift: shifts the location of the grid (from eye-tracking data)
            out_idx: index (or boolean mask as np.ndarray) of neurons to be predicted

        Returns:
            y: neuronal activity of shape (N, number of selected units)

        Raises:
            ValueError: if the feature map dimensions differ from ``in_shape``.
        """
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, w, h = x.size()
        c_in, w_in, h_in = self.in_shape
        if [c_in, w_in, h_in] != [c, w, h]:
            raise ValueError(
                "the specified feature map dimension is not the readout's expected input dimension"
            )

        m = self.pool_steps + 1  # the input feature is considered the first pooling stage
        feat = self.features.view(1, m * c, self.outdims)

        if out_idx is None:
            grid = self.grid
            bias = self.bias
            outdims = self.outdims
        else:
            if isinstance(out_idx, np.ndarray) and out_idx.dtype == bool:
                out_idx = np.where(out_idx)[0]
            feat = feat[:, :, out_idx]
            grid = self.grid[:, out_idx]
            bias = self.bias[out_idx] if self.bias is not None else None
            outdims = len(out_idx)

        if shift is None:
            grid = grid.expand(N, outdims, 1, 2)
        else:
            # shift grid based on shifter network's prediction
            grid = grid.expand(N, outdims, 1, 2) + shift[:, None, None, :]

        pools = [F.grid_sample(x, grid, align_corners=self.align_corners)]
        for _ in range(self.pool_steps):
            _, _, w_pool, h_pool = x.size()
            if w_pool * h_pool == 1:
                warnings.warn(
                    "redundant pooling steps: pooled feature map size is already 1X1, consider reducing it"
                )
            x = self.avg(x)
            pools.append(F.grid_sample(x, grid, align_corners=self.align_corners))
        y = torch.cat(pools, dim=1)
        y = (y.squeeze(-1) * feat).sum(1).view(N, outdims)

        if self.bias is not None:
            y = y + bias
        return y

    def __repr__(self):
        c, w, h = self.in_shape
        r = self.__class__.__name__ + " (" + "{} x {} x {}".format(
            c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        r += " and pooling for {} steps\n".format(self.pool_steps)
        for ch in self.children():
            r += "  -> " + ch.__repr__() + "\n"
        return r
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: Layer of nodes the connection originates from.
        :param target: Layer of nodes the connection projects to.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to some rule.
        :param torch.Tensor w: Strengths of synapses; defaults to random values.
        :param torch.Tensor b: Target population bias; no bias if omitted.
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        has_lower = self.wmin != -np.inf
        has_upper = self.wmax != np.inf
        w = kwargs.get("w", None)
        if w is not None:
            # User-supplied weights are clamped whenever any bound is finite.
            if has_lower or has_upper:
                w = torch.clamp(w, self.wmin, self.wmax)
        elif has_lower and has_upper:
            # Both bounds finite: uniform in [wmin, wmax).
            w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            # At least one bound infinite: uniform in [0, 1), clamped.
            w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)

        b = kwargs.get("b", None)
        self.b = None if b is None else Parameter(b, requires_grad=False)

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes; the first dimension is the batch.
        :return: Incoming spikes multiplied by synaptic weights (with or without
            decaying spike activation), shaped ``(batch, *target.shape)``.
        """
        flat_spikes = s.view(s.size(0), -1).float()
        if self.b is None:
            post = flat_spikes @ self.w
        else:
            post = flat_spikes @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule (delegated to the base class).
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Rescale weights in place so each target neuron's total absolute incoming
        weight equals ``self.norm`` (no-op when ``norm`` is None).
        """
        if self.norm is None:
            return
        column_sums = self.w.abs().sum(0).unsqueeze(0)
        column_sums[column_sums == 0] = 1.0  # leave all-zero columns untouched
        self.w *= self.norm / column_sums

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection (delegated to the base class).
        """
        super().reset_state_variables()
class PointPyramid2d(Readout):
    """
    Point readout over a multi-scale pyramid of a 2-D feature map.

    Every output unit owns one learned sampling location (``grid``, clipped to
    ``[-1, 1]`` at every forward pass). The location is sampled with
    ``F.grid_sample`` at each level produced by ``self.gauss_pyramid``. The sampled
    channels of all levels are then combined linearly with that unit's ``features``
    weights, and an optional bias is added.
    """

    def __init__(
        self,
        in_shape,
        outdims,
        scale_n,
        positive,
        bias,
        init_range,
        downsample,
        type,
        align_corners=True,
        mean_activity=None,
        feature_reg_weight=1.0,
        gamma_readout=None,  # deprecated, use feature_reg_weight instead
        **kwargs,
    ):
        super().__init__()
        self.in_shape = in_shape
        n_channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive
        self.feature_reg_weight = self.resolve_deprecated_gamma_readout(
            feature_reg_weight, gamma_readout)
        self.mean_activity = mean_activity
        self.gauss_pyramid = Pyramid(scale_n=scale_n, downsample=downsample, type=type)

        # One (x, y) location per output unit.
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        # Channel weights for every pyramid level (scale_n + 1 levels).
        n_levels = scale_n + 1
        self.features = Parameter(torch.Tensor(1, n_channels * n_levels, 1, outdims))
        self.register_parameter(
            "bias", Parameter(torch.Tensor(outdims)) if bias else None)

        self.init_range = init_range
        self.align_corners = align_corners
        self.initialize(mean_activity)

    def initialize(self, mean_activity=None):
        """
        Draw grid locations from ``U(-init_range, init_range)``, set every feature
        weight to ``1 / channels`` and, if a bias exists, initialise it via
        ``initialize_bias`` (falling back to ``self.mean_activity``).
        """
        mean_activity = self.mean_activity if mean_activity is None else mean_activity
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.initialize_bias(mean_activity=mean_activity)

    def group_sparsity(self, group_size):
        """
        Group-sparsity penalty: mean RMS of each block of ``group_size`` feature
        channels, divided by ``features.size(1) // group_size`` and summed.
        """
        n_features = self.features.size(1)
        n_groups = n_features // group_size
        return sum(
            (self.features[:, start:start + group_size, ...].pow(2).mean(1) + 1e-12)
            .sqrt().mean() / n_groups
            for start in range(0, n_features, group_size)
        )

    def feature_l1(self, reduction="sum", average=None):
        """L1 norm of the feature weights, reduced with ``apply_reduction``."""
        magnitudes = self.features.abs()
        return self.apply_reduction(magnitudes, reduction=reduction, average=average)

    def regularizer(self, reduction="sum", average=None):
        """Feature L1 penalty scaled by ``feature_reg_weight``."""
        penalty = self.feature_l1(reduction=reduction, average=average)
        return penalty * self.feature_reg_weight

    def forward(self, x, shift=None):
        """
        Read out every unit from ``x`` (unpacked as ``(N, c, w, h)``).

        :param shift: optional per-sample offset added to every unit's grid location;
            presumably shaped ``(N, 2)`` — TODO confirm against callers.
        :return: tensor of shape ``(N, outdims)``.
        """
        if self.positive:
            self.features.data.clamp_min_(0)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)

        batch, n_channels, _, _ = x.size()
        n_levels = self.gauss_pyramid.scale_n + 1
        weights = self.features.view(1, n_levels * n_channels, self.outdims)

        grid = self.grid.expand(batch, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        samples = [
            F.grid_sample(level, grid, align_corners=self.align_corners)
            for level in self.gauss_pyramid(x)
        ]
        stacked = torch.cat(samples, dim=1).squeeze(-1)
        out = (stacked * weights).sum(1).view(batch, self.outdims)
        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        c, w, h = self.in_shape
        text = "{} ({} x {} x {} -> {})".format(
            self.__class__.__name__, c, w, h, self.outdims)
        if self.bias is not None:
            text += " with bias"
        for child in self.children():
            text += " -> " + child.__repr__() + "\n"
        return text
class DynamicConnection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons. The weight matrix
    is allowed to rewire dynamically: :meth:`sp` can create synapses where the weight
    is currently zero and prune existing ones.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`DynamicConnection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias (defaults to zeros).
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param float prune_thresh: Non-zero weights with absolute value below this
            threshold are pruned by :meth:`sp` (disabled when ``<= 0``; default 0.0).
        :param float prune_prob: Per-synapse probability of random pruning in
            :meth:`sp` (disabled when ``<= 0``; default 0.0).
        :param float create_prob: Per-synapse probability that an absent (zero-weight)
            synapse is created in :meth:`sp` (disabled when ``<= 0``; default 0.0).
        :param bool create: Enable activity-dependent synaptogenesis in :meth:`sp`
            (default False).
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        elif self.wmin != -np.inf or self.wmax != np.inf:
            w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False)

        self.prune_thresh = kwargs.get("prune_thresh", 0.0)
        self.prune_prob = kwargs.get("prune_prob", 0.0)
        self.create_prob = kwargs.get("create_prob", 0.0)
        self.create = kwargs.get("create", False)

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming spikes.
        :return: Incoming spikes multiplied by synaptic weights plus bias, reshaped to
            ``(batch, *target.shape)``.
        """
        post = s.float().view(s.size(0), -1) @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's (functional) update rule by delegating to the base class.

        Structural plasticity (synaptogenesis and pruning) is not run here; call
        :meth:`sp` for that.
        """
        super().update(**kwargs)

    def sp(self) -> Tuple:
        # language=rst
        """
        Run one step of structural plasticity on ``self.w`` (in place).

        Order of mechanisms: probabilistic synaptogenesis, activity-dependent
        synaptogenesis, threshold pruning, probabilistic pruning.

        :return: ``(total_conns_created, total_conns_pruned)``: the number of synapses
            whose weight actually went from zero to non-zero / non-zero to zero. Each
            entry is the int ``0`` when no enabled mechanism contributed to it,
            otherwise a 0-dim tensor.
        """
        total_conns_created = 0
        total_conns_pruned = 0
        w = self.w.data  # modified in place below
        device = w.device

        # Probabilistic synaptogenesis: each absent (zero) synapse is created with
        # probability create_prob. All synapses created in one call share a single
        # initial value: 0.3 * one draw from U(wmin, wmax).
        if self.create_prob > 0.0:
            create_mask = (torch.rand(w.shape, device=device) < self.create_prob) & (w == 0.0)
            w[create_mask] = 0.3 * np.random.uniform(self.wmin, self.wmax)
            total_conns_created += create_mask.sum()

        # Activity-dependent synaptogenesis: create synapses between source and
        # target neurons whose traces ``x`` exceed 0.5.
        if self.create:
            batch_size = self.source.batch_size
            source_active = self.source.x.view(batch_size, -1).unsqueeze(2) > 0.50
            target_active = self.target.x.view(batch_size, -1).unsqueeze(1) > 0.50

            if source_active.any() and target_active.any():
                # Candidate weights 0.3 * U(wmin, wmax), kept only where no synapse exists.
                weight_mask = torch.empty_like(w).uniform_(self.wmin, self.wmax) * 0.3
                weight_mask[w != 0.0] = 0.0
                # NOTE(review): only the first sample of the batch decides which
                # rows/columns are active.
                weight_mask[~source_active[0, :, 0].to(device), :] = 0.0
                weight_mask[:, ~target_active[0, 0, :].to(device)] = 0.0

                w += weight_mask
                total_conns_created += (weight_mask != 0.0).sum()

        # Threshold pruning: zero every non-zero weight whose magnitude is below
        # prune_thresh (handles positive and negative weights alike).
        if self.prune_thresh > 0.0:
            prune_mask = (w != 0.0) & (w.abs() < self.prune_thresh)
            w[prune_mask] = 0.0
            total_conns_pruned += prune_mask.sum()

        # Probabilistic pruning: each existing synapse is removed with probability
        # prune_prob. Already-zero entries are not counted as pruned.
        if self.prune_prob > 0.0:
            prune_mask = (torch.rand(w.shape, device=device) < self.prune_prob) & (w != 0.0)
            w[prune_mask] = 0.0
            total_conns_pruned += prune_mask.sum()

        return (total_conns_created, total_conns_pruned)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal
        to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
class SpatialTransformerPyramid3d(nn.Module):
    """
    Spatial-transformer readout applied frame by frame to a 5-D input.

    :meth:`forward` unpacks its input as ``(N, c, t, w, h)``, folds time into the
    batch, and passes every frame through a ``Pyramid``. Each output unit then samples
    one learned location (``grid``) per pyramid level and combines the sampled
    channels with its ``features`` weights. The result has shape
    ``(N, t, outdims)``.

    Args:
        in_shape: 4-tuple unpacked as ``(c, _, w, h)``; the second entry is unused
            here (presumably time — TODO confirm).
        outdims: number of output units.
        scale_n: forwarded to ``Pyramid``; feature weights cover ``scale_n + 1`` levels.
        positive: if True, ``positive(self.features)`` is applied in every forward pass.
        bias: whether to learn a per-unit additive bias.
        init_range: grid locations are initialised from ``U(-init_range, init_range)``.
        downsample, _skip_upsampling, type: forwarded to ``Pyramid``.
    """

    def __init__(self, in_shape, outdims, scale_n=4, positive=True, bias=True, init_range=.05,
                 downsample=True, _skip_upsampling=False, type=None):
        super().__init__()
        self.in_shape = in_shape
        c, _, w, h = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss = Pyramid(scale_n=scale_n, downsample=downsample,
                             _skip_upsampling=_skip_upsampling, type=type)
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(torch.Tensor(1, c * (scale_n + 1), 1, outdims))

        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter('bias', bias)
        else:
            self.register_parameter('bias', None)

        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw grid locations, set features to ``1 / channels`` and zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)

    def feature_l1(self, average=True, subs_idx=None):
        """
        L1 norm of the feature weights (mean if ``average`` else sum).

        :raises NotImplementedError: if ``subs_idx`` is given (subsampling unsupported).
        """
        if subs_idx is not None:
            raise NotImplementedError('Subsample is not implemented.')
        if average:
            return self.features.abs().mean()
        else:
            return self.features.abs().sum()

    def forward(self, x, shift=None, subs_idx=None):
        """
        Read out every unit at every time step.

        :param x: input unpacked as ``(N, c, t, w, h)``.
        :param shift: optional offsets indexed as ``shift[:, i, :]`` for time step ``i``
            and added to the grid (presumably ``(N, t, 2)`` — TODO confirm).
        :param subs_idx: unsupported; must be None.
        :return: tensor of shape ``(N, t, outdims)``.
        :raises NotImplementedError: if ``subs_idx`` is given.
        """
        if subs_idx is not None:
            raise NotImplementedError('Subsample is not implemented.')
        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        N, c, t, w, h = x.size()
        m = self.gauss.scale_n + 1
        feat = self.features.view(1, m * c, self.outdims)

        if shift is None:
            grid = self.grid.expand(N * t, self.outdims, 1, 2)
        else:
            grid = self.grid.expand(N, self.outdims, 1, 2)
            # One shifted grid per time step, then fold time into the batch so the
            # ordering matches ``z`` below (sample-major, time-minor).
            grid = torch.stack([grid + shift[:, i, :][:, None, None, :] for i in range(t)], 1)
            grid = grid.contiguous().view(-1, self.outdims, 1, 2)

        # (N, c, t, w, h) -> (N * t, c, w, h)
        z = x.contiguous().transpose(2, 1).contiguous().view(-1, c, w, h)
        pools = [F.grid_sample(level, grid) for level in self.gauss(z)]
        y = torch.cat(pools, dim=1).squeeze(-1)
        y = (y * feat).sum(1).view(N, t, self.outdims)

        if self.bias is not None:
            y = y + self.bias
        return y

    def __repr__(self):
        c, t, w, h = self.in_shape
        r = self.__class__.__name__ + \
            ' (' + '{} x {} x {}'.format(c, w, h) + ' -> ' + str(self.outdims) + ')'
        if self.bias is not None:
            r += ' with bias'
        for ch in self.children():
            r += '\n -> ' + ch.__repr__()
        return r
class SpatialTransformerPooled2d(nn.Module):
    """
    Spatial-transformer readout over successively average-pooled copies of a 2-D input.

    Each output unit owns one learned sampling location (``grid``, clipped to
    ``[-1, 1]`` every forward pass). The input is sampled there with ``F.grid_sample``.
    It is then average-pooled ``pool_steps`` times, and sampled again after every
    pooling step. The sampled channels of all levels are combined with the unit's
    ``features`` weights, and an optional bias is added.

    Args:
        in_shape: ``(channels, width, height)`` of the input feature map.
        outdims: number of output units.
        pool_steps: number of pooling steps; ``pool_steps + 1`` levels are read out.
        positive: if True, ``positive(self.features)`` is applied each forward pass.
        bias: whether to learn a per-unit additive bias.
        pool_kern: kernel size and stride of the ``nn.AvgPool2d`` used for pooling.
        init_range: grid locations are initialised from ``U(-init_range, init_range)``.
    """

    def __init__(self, in_shape, outdims, pool_steps=1, positive=False, bias=True,
                 pool_kern=2, init_range=.1):
        super().__init__()
        self.pool_steps = pool_steps
        self.in_shape = in_shape
        n_channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive

        n_levels = pool_steps + 1
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(torch.Tensor(1, n_channels * n_levels, 1, outdims))
        self.register_parameter('bias', Parameter(torch.Tensor(outdims)) if bias else None)

        self.pool_kern = pool_kern
        self.avg = nn.AvgPool2d((pool_kern, pool_kern), stride=pool_kern,
                                count_include_pad=False)
        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw grid locations uniformly, set features to ``1 / channels``, zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)

    def feature_l1(self, average=True):
        """L1 norm of the feature weights: mean if ``average`` else sum."""
        magnitudes = self.features.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def group_sparsity(self, group_size):
        """
        Group-sparsity penalty: mean RMS of each block of ``group_size`` feature
        channels, divided by ``features.size(1) // group_size`` and summed.
        """
        n_features = self.features.size(1)
        n_groups = n_features // group_size
        return sum(
            (self.features[:, start:start + group_size, ...].pow(2).mean(1) + 1e-12)
            .sqrt().mean() / n_groups
            for start in range(0, n_features, group_size)
        )

    def forward(self, x, shift=None):
        """
        Read out every unit from ``x`` (unpacked as ``(N, c, w, h)``).

        :param shift: optional per-sample offset added to every grid location
            (presumably ``(N, 2)`` — TODO confirm).
        :return: tensor of shape ``(N, outdims)``.
        """
        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)

        batch, n_channels, _, _ = x.size()
        n_levels = self.pool_steps + 1
        weights = self.features.view(1, n_levels * n_channels, self.outdims)

        grid = self.grid.expand(batch, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        samples = []
        level = x
        for step in range(n_levels):
            if step > 0:
                level = self.avg(level)
            samples.append(F.grid_sample(level, grid))

        stacked = torch.cat(samples, dim=1).squeeze(-1)
        out = (stacked * weights).sum(1).view(batch, self.outdims)
        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        c, w, h = self.in_shape
        parts = [self.__class__.__name__, ' (', '{} x {} x {}'.format(c, w, h),
                 ' -> ', str(self.outdims), ')']
        if self.bias is not None:
            parts.append(' with bias')
        parts.append(' and pooling for {} steps\n'.format(self.pool_steps))
        parts.extend(' -> ' + child.__repr__() + '\n' for child in self.children())
        return ''.join(parts)
class SpatialTransformerPyramid2d(nn.Module):
    """
    Spatial-transformer readout over a multi-scale pyramid of a 2-D feature map.

    Each output unit owns one learned sampling location (``grid``, clipped to
    ``[-1, 1]`` every forward pass). The location is read out with ``F.grid_sample``
    at every level returned by ``self.gauss_pyramid``. The sampled channels of all
    levels are combined with the unit's ``features`` weights, and an optional bias
    is added.

    Args:
        in_shape: ``(channels, width, height)`` of the input feature map.
        outdims: number of output units.
        scale_n: forwarded to ``Pyramid``; feature weights cover ``scale_n + 1`` levels.
        positive: if True, ``positive(self.features)`` is applied before every readout.
        bias: whether to learn a per-unit additive bias.
        init_range: grid locations are initialised from ``U(-init_range, init_range)``.
        downsample, _skip_upsampling, type: forwarded to ``Pyramid``.
    """

    def __init__(self, in_shape, outdims, scale_n=4, positive=False, bias=True, init_range=.1,
                 downsample=True, _skip_upsampling=False, type=None):
        super().__init__()
        self.in_shape = in_shape
        n_channels, _, _ = in_shape
        self.outdims = outdims
        self.positive = positive
        self.gauss_pyramid = Pyramid(scale_n=scale_n, downsample=downsample,
                                     _skip_upsampling=_skip_upsampling, type=type)

        n_levels = scale_n + 1
        self.grid = Parameter(torch.Tensor(1, outdims, 1, 2))
        self.features = Parameter(torch.Tensor(1, n_channels * n_levels, 1, outdims))
        self.register_parameter('bias', Parameter(torch.Tensor(outdims)) if bias else None)

        self.init_range = init_range
        self.initialize()

    def initialize(self):
        """Draw grid locations uniformly, set features to ``1 / channels``, zero the bias."""
        self.grid.data.uniform_(-self.init_range, self.init_range)
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)

    def group_sparsity(self, group_size):
        """
        Group-sparsity penalty: mean RMS of each block of ``group_size`` feature
        channels, divided by ``features.size(1) // group_size`` and summed.
        """
        n_features = self.features.size(1)
        n_groups = n_features // group_size
        return sum(
            (self.features[:, start:start + group_size, ...].pow(2).mean(1) + 1e-12)
            .sqrt().mean() / n_groups
            for start in range(0, n_features, group_size)
        )

    def feature_l1(self, average=True):
        """L1 norm of the feature weights: mean if ``average`` else sum."""
        magnitudes = self.features.abs()
        return magnitudes.mean() if average else magnitudes.sum()

    def neuron_layer_power(self, x, neuron_id):
        """
        Mean squared value of the full-image feature map weighted by unit
        ``neuron_id``'s feature weights (summed over channels of all levels).

        NOTE(review): the pyramid levels are concatenated along channels, so this
        requires every level to have the same spatial size (likely only with
        ``downsample=False``) — confirm.
        """
        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)
        _, n_channels, _, _ = x.size()
        n_levels = self.gauss_pyramid.scale_n + 1
        weights = self.features.view(1, n_levels * n_channels, self.outdims)
        levels = torch.cat(self.gauss_pyramid(x), dim=1)
        weighted = (levels * weights[:, :, neuron_id, None, None]).sum(1)
        return weighted.pow(2).mean()

    def forward(self, x, shift=None):
        """
        Read out every unit from ``x`` (unpacked as ``(N, c, w, h)``).

        :param shift: optional per-sample offset added to every grid location
            (presumably ``(N, 2)`` — TODO confirm).
        :return: tensor of shape ``(N, outdims)``.
        """
        if self.positive:
            positive(self.features)
        self.grid.data = torch.clamp(self.grid.data, -1, 1)

        batch, n_channels, _, _ = x.size()
        n_levels = self.gauss_pyramid.scale_n + 1
        weights = self.features.view(1, n_levels * n_channels, self.outdims)

        grid = self.grid.expand(batch, self.outdims, 1, 2)
        if shift is not None:
            grid = grid + shift[:, None, None, :]

        samples = [F.grid_sample(level, grid) for level in self.gauss_pyramid(x)]
        stacked = torch.cat(samples, dim=1).squeeze(-1)
        out = (stacked * weights).sum(1).view(batch, self.outdims)
        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        c, w, h = self.in_shape
        parts = [self.__class__.__name__, ' (', '{} x {} x {}'.format(c, w, h),
                 ' -> ', str(self.outdims), ')']
        if self.bias is not None:
            parts.append(' with bias')
        parts.extend(' -> ' + child.__repr__() + '\n' for child in self.children())
        return ''.join(parts)
class AllToAllConnection(ABC, Module):
    """
    Dense connection whose postsynaptic input is driven by a per-synapse pool of
    active neurotransmitters.

    On every call to :meth:`compute`, each pool decays by ``pool / tc_synaptic`` and
    gains ``phi`` for every presynaptic spike. Only synapses with non-zero weight are
    updated. The input returned to the target is
    ``(v_rev - target.v) * max(w) * S``, where ``S`` is the summed pool per target
    neuron. Note that the single largest weight of the whole matrix is used, not the
    individual synapse weights.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        w: Optional[torch.Tensor],
        tc_synaptic: float = 0.0,
        phi: float = 0.0,
        nu: Optional[Union[float, Sequence[float]]] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        """
        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param w: Strengths of synapses (wrapped in a non-trainable ``Parameter``).
        :param tc_synaptic: Time constant of the neurotransmitter pool decay; the pool
            is divided by it in :meth:`compute`.
            NOTE(review): the default 0.0 makes that division produce inf/nan; callers
            presumably must pass a positive value — confirm.
        :param phi: Amount added to a synapse's pool for each presynaptic spike.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments (all of them are also forwarded to the learning rule):

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule (``NoOp`` when omitted or None).
        :param torch.Tensor b: Target population bias (defaults to zeros).
        :param float wmin: The minimum value on the connection weights.
        :param float wmax: The maximum value on the connection weights.
        :param float norm: Total weight per target neuron normalization.
        """
        super().__init__()  # initialisation of Module

        assert isinstance(source, Nodes), "Source is not a Nodes object"
        assert isinstance(target, Nodes), "Target is not a Nodes object"

        self.source = source
        self.target = target
        self.nu = nu
        self.weight_decay = weight_decay

        self.update_rule = kwargs.get("update_rule", NoOp)
        self.wmin = kwargs.get("wmin", -np.inf)
        self.wmax = kwargs.get("wmax", np.inf)
        self.norm = kwargs.get("norm", None)

        # Learning rule (an explicit update_rule=None also falls back to NoOp).
        if self.update_rule is None:
            self.update_rule = NoOp

        self.update_rule = self.update_rule(
            connection=self,
            nu=nu,
            weight_decay=weight_decay,
            **kwargs)

        # Weights
        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False)

        # State used to compute the synaptic input.
        self.active_neurotransmitters = torch.zeros(self.source.n, self.target.n)
        self.tc_synaptic = tc_synaptic
        self.phi = phi
        self.v_rev = 0  # presumably a reversal potential — confirm
        self.cumul_I = None

    def get_dirac(self):
        """
        Return 1.0 for every (pre, post) pair in which the presynaptic or the
        postsynaptic neuron spiked, else 0.0 (elementwise max of the broadcast spike
        tensors; assumes ``target.s`` broadcasts against a ``(source.n, 1)`` column).
        """
        pre_s = self.source.s.view(-1).unsqueeze(1)
        post_s = self.target.s
        return torch.max(pre_s, post_s).float()

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations of downstream neurons given spikes of upstream neurons.

        Updates ``self.active_neurotransmitters`` in place as a side effect.

        :param s: Incoming spikes (flattened to one value per source neuron).
        :return: ``(v_rev - target.v) * max(w) * S`` with ``S`` of shape ``(1, target.n)``.
        """
        # Broadcast each presynaptic spike across its row of synapses.
        pre_spike_occured = s.float().view(-1, 1).expand_as(self.active_neurotransmitters)

        update = -self.active_neurotransmitters / self.tc_synaptic + self.phi * pre_spike_occured
        # Synapses with zero weight do not exist: keep their pool unchanged.
        update = torch.where(self.w != 0, update, torch.zeros_like(update))
        self.active_neurotransmitters += update

        # Total active neurotransmitters arriving at each target neuron.
        S = torch.sum(self.active_neurotransmitters.t(), dim=1, keepdim=True).view(1, -1)
        return (self.v_rev - self.target.v) * torch.max(self.w) * S

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.

        Keyword arguments:

        :param bool learning: Whether to allow connection updates (required; a
            missing key raises ``KeyError``).
        :param ByteTensor mask: Boolean mask determining which weights to clamp to zero.
        """
        learning = kwargs["learning"]
        if learning:
            self.update_rule.update(**kwargs)

        mask = kwargs.get("mask", None)
        if mask is not None:
            self.w.masked_fill_(mask, 0)

    def normalize(self) -> None:
        """
        Normalize weights so each target neuron has sum of connection weights equal
        to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        """
        Contains resetting logic for the connection (currently nothing is reset).
        """
        pass
class ConcatConnection(AbstractConnection):
    # language=rst
    """
    Connection from several source populations, whose spikes are concatenated, onto a
    single target population.
    """

    def __init__(
        self,
        source: Dict[str, Nodes],
        target: Nodes,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        :param source: Mapping of name -> layer of nodes; the weight matrix has one row
            per neuron across all of these layers.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.

        Keyword arguments:

        :param torch.Tensor w: Strengths of synapses. When omitted, weights start at
            ``wmin`` if both bounds are finite, otherwise at zero clamped to the bounds.
        :param torch.Tensor b: Target population bias (defaults to zeros).
        """
        super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

        w = kwargs.get("w", None)
        # Total number of presynaptic neurons across all concatenated sources.
        # (np.sum on a generator is deprecated; the builtin sum gives the same result.)
        source_n = sum(nodes.n for nodes in source.values())
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.zeros(source_n, target.n), self.wmin, self.wmax)
            else:
                w = self.wmin + torch.zeros(source_n, target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, requires_grad=False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), requires_grad=False)

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        :param s: Incoming (concatenated) spikes.
        :return: Incoming spikes multiplied by synaptic weights plus bias, reshaped to
            ``(batch, *target.shape)``.
        """
        post = s.float().view(s.size(0), -1) @ self.w + self.b
        return post.view(s.size(0), *self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal
        to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
class Connection(AbstractConnection):
    # language=rst
    """
    Specifies synapses between one or two populations of neurons.

    Presynaptic spikes are turned into finite-length impulses (see
    :meth:`update_impulse_state` and :meth:`impulse_curve`). The impulses are
    accumulated into ``a_pre``, and the postsynaptic input is ``a_pre @ w``.
    """

    # NOTE(review): this module defines another class named ``Connection`` earlier;
    # whichever definition comes last replaces the other at import time.

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        impulse_amplitude: float,
        impulse_length: float,
        impulse_shape_factor: float = 0.9,
        invert: bool = False,
        nu: Optional[Union[float, Sequence[float]]] = None,
        reduction: Optional[callable] = None,
        weight_decay: float = 0.0,
        post_spike_weight_decay: float = 0.0,
        **kwargs
    ) -> None:
        # language=rst
        """
        Instantiates a :code:`Connection` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param impulse_amplitude: Scale of the impulse produced by a presynaptic spike.
        :param impulse_length: Number of steps after which an impulse counter is reset
            (see :meth:`update_impulse_state`).
        :param impulse_shape_factor: Fraction ``k`` of ``impulse_length`` at which the
            impulse changes slope (see :meth:`impulse_curve`).
        :param invert: Select the alternative (inverted) impulse shape.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param weight_decay: Constant multiple to decay weights by on each iteration.
        :param post_spike_weight_decay: Forwarded to the base class (semantics defined
            there).

        Keyword arguments:

        :param LearningRule update_rule: Modifies connection parameters according to
            some rule.
        :param torch.Tensor w: Strengths of synapses.
        :param torch.Tensor b: Target population bias (defaults to zeros).
        :param float wmin: Minimum allowed value on the connection weights.
        :param float wmax: Maximum allowed value on the connection weights.
        :param float norm: Total weight per target neuron normalization constant.
        :param norm_by_max: If truthy, :meth:`normalize_by_max` divides each target
            neuron's weights by their largest magnitude.
        :param norm_by_max_from_shadow_weights: If truthy, keep unnormalized "shadow"
            weights and normalize ``w`` from them by their max magnitude
            (:meth:`normalize_by_max_from_shadow_weights`).
        """
        super().__init__(source, target, impulse_amplitude, impulse_length,
                         impulse_shape_factor, invert, nu, reduction, weight_decay,
                         post_spike_weight_decay, **kwargs)

        w = kwargs.get("w", None)
        if w is None:
            if self.wmin == -np.inf or self.wmax == np.inf:
                w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
            else:
                w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
        else:
            if self.wmin != -np.inf or self.wmax != np.inf:
                w = torch.clamp(w, self.wmin, self.wmax)

        self.w = Parameter(w, False)
        self.b = Parameter(kwargs.get("b", torch.zeros(target.n)), False)

        if self.norm_by_max_from_shadow_weights:
            self.shadow_w = self.w.clone().detach()
            self.prev_w = self.w.clone().detach()

    def update_impulse_state(self, s):
        """
        Advance the per-neuron impulse counters and return the current impulse.

        Side effects: zeroes, in place, the entries of ``s`` and ``self.source.s`` for
        neurons whose impulse is still running, so a neuron cannot start a new impulse
        mid-impulse.
        """
        # Advance counters of impulses already in progress by one step.
        self.impulse_state += (self.impulse_state > 0).float().view(-1)
        # Suppress spikes from neurons that are mid-impulse.
        s[self.impulse_state.unsqueeze(0) > 0] = 0
        self.source.s[self.impulse_state.unsqueeze(0) > 0] = 0
        # Start a new impulse (counter = 1) for idle neurons that spiked.
        self.impulse_state += (self.impulse_state == 0).float() * s.float().view(-1)
        impulse = self.impulse_curve()
        # Impulses that reached impulse_length are finished: reset their counters.
        self.impulse_state *= (self.impulse_state < self.impulse_length).float()
        return impulse

    def impulse_curve(self):
        """
        Per-neuron impulse increment for the current counter values in
        ``self.impulse_state``, piecewise in the counter with a slope change around
        ``impulse_length * k`` (or ``impulse_length * (1 - k)`` when ``invert``).
        """
        k = self.impulse_shape_factor
        state = self.impulse_state
        length = self.impulse_length
        amplitude = self.impulse_amplitude
        if self.invert:
            # Slopes at points that are not in the middle of the impulse.
            impulse_value_2 = amplitude / (length * k - 1)
            impulse_value_1 = amplitude / (length * (1 - k))
            turn = length * (1 - k)
            impulse_bias = (2 * amplitude
                            * (state > (turn + 0.5)).float().view(-1)
                            * (state <= (turn + 1.5)).float().view(-1))
            impulse = ((-impulse_value_1)
                       * (state > 0).float().view(-1)
                       * (state <= (turn + 0.5)).float().view(-1)
                       + (-impulse_value_2) * (state > (turn + 1.5)).float().view(-1)
                       + impulse_bias)
            return impulse
        else:
            # Slope at points that are not in the middle of the impulse.
            impulse_value = amplitude / (length - 2) / k
            peak = length * k
            impulse_bias = (2 * amplitude * (abs(state - peak) < 0.5).float().view(-1)
                            + 2 * amplitude * (abs(state - peak) == 0.5).float().view(-1)
                            * (state < peak).float().view(-1))
            impulse = (impulse_value
                       * (state > 0).float().view(-1)
                       * (state < peak).float().view(-1)
                       + impulse_value / ((1 - k) / k) * (state > peak).float().view(-1)
                       - impulse_bias)
            return impulse

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Compute pre-activations given spikes using connection weights.

        Updates the impulse state and accumulates the new impulse into ``a_pre``
        (in place); note that ``s`` itself may be modified by
        :meth:`update_impulse_state`.

        :param s: Incoming spikes.
        :return: ``a_pre @ w`` reshaped to ``target.shape``.
        """
        impulse = self.update_impulse_state(s)
        self.a_pre += impulse
        a_post = self.a_pre @ self.w
        return a_post.view(*self.target.shape)

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Compute connection's update rule.
        """
        super().update(**kwargs)

    def normalize(self) -> None:
        # language=rst
        """
        Normalize weights so each target neuron has sum of connection weights equal
        to ``self.norm``.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w *= self.norm / w_abs_sum

    def normalize_by_max(self) -> None:
        # language=rst
        """
        Normalize weights by the max absolute weight of each target neuron.
        """
        if self.norm_by_max:
            w_max = self.w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0
            self.w /= w_max

    def normalize_by_max_from_shadow_weights(self) -> None:
        # language=rst
        """
        Normalize weights by the max absolute shadow weight of each target neuron.

        Weight changes made since the last call (``w - prev_w``) are first applied
        to the unnormalized shadow weights.
        """
        if self.norm_by_max_from_shadow_weights:
            self.shadow_w += self.w - self.prev_w
            w_max = self.shadow_w.abs().max(0)[0]
            w_max[w_max == 0] = 1.0
            # Write into the existing Parameter: assigning a plain tensor to the
            # registered parameter ``w`` is rejected by nn.Module.__setattr__.
            self.w.data.copy_(self.shadow_w / w_max)
            self.prev_w = self.w.clone().detach()

    def reset_(self) -> None:
        # language=rst
        """
        Contains resetting logic for the connection: clears ``a_pre`` and the
        impulse counters.
        """
        super().reset_()
        self.a_pre = torch.zeros_like(self.a_pre)
        self.impulse_state = torch.zeros_like(self.impulse_state)
class Gaussian2d(nn.Module):
    """
    Readout that learns, for each neuron, a point in the core feature space.

    The point is sampled from a Gaussian distribution N(mu, sigma) at train time and set to the
    mean mu at eval time. The readout receives:

    - the shape of the core as 'in_shape';
    - the number of units/neurons being predicted as 'outdims';
    - 'bias', specifying whether or not a bias term is used;
    - 'init_mu_range' / 'init_sigma_range', used to initialise the Gaussian.

    Initialisation draws mu from U(-init_mu_range, init_mu_range). sigma is drawn from
    U(0.0, init_sigma_range), or set to exactly init_sigma_range when fixed_sigma is True.

    The grid locations are normalized (x, y) coordinates in the core feature space. They are
    clipped to [-1, 1], as required by torch.nn.functional.grid_sample.

    The feature parameter learns a linear mapping from the feature vector at the chosen
    location to the unit's response. That location is sampled at train time and is the mean at
    eval time. No output non-linearity is applied inside this module.

    Args:
        in_shape (list): shape of the input feature map [channels, width, height]
        outdims (int): number of output units
        bias (bool): adds a bias term
        init_mu_range (float): initialises the mean with Uniform([-init_mu_range, init_mu_range])
            [expected: positive value <= 1]
        init_sigma_range (float): initialises sigma with Uniform([0.0, init_sigma_range]).
            A fixed initialization is recommended for faster convergence; for this, set
            fixed_sigma to True.
        batch_sample (bool): if True, samples a position for each image in the batch separately
            [default: True as it decreases convergence time and performs just as well]
        align_corners (bool): keyword argument to grid_sample for bilinear interpolation.
            Its behavior changed in PyTorch 1.3. The default align_corners=True reproduces the
            pre-1.3 behavior for comparability.
        fixed_sigma (bool): recommended value is True, but the default is False for backwards
            compatibility. If True, sigma is initialized to exactly init_sigma_range for all
            neurons instead of being sampled from a range.
    """

    def __init__(self, in_shape, outdims, bias, init_mu_range=0.5, init_sigma_range=0.5,
                 batch_sample=True, align_corners=True, fixed_sigma=False, **kwargs):
        warnings.warn(
            "Gaussian2d is deprecated and will be removed in the future. Use `layers.readout.NonIsoGaussian2d` instead",
            DeprecationWarning,
        )
        super().__init__()
        # init_mu_range must lie in (0, 1] because grid locations live in [-1, 1].
        if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma_range <= 0.0:
            raise ValueError(
                "either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive"
            )
        self.in_shape = in_shape
        c, w, h = in_shape
        self.outdims = outdims
        self.batch_sample = batch_sample
        self.grid_shape = (1, outdims, 1, 2)
        # mean location of the Gaussian for each neuron
        self.mu = Parameter(torch.Tensor(*self.grid_shape))
        # standard deviation of the Gaussian for each neuron
        self.sigma = Parameter(torch.Tensor(*self.grid_shape))
        # feature weights for each channel of the core
        self.features = Parameter(torch.Tensor(1, c, 1, outdims))
        if bias:
            bias = Parameter(torch.Tensor(outdims))
            self.register_parameter("bias", bias)
        else:
            self.register_parameter("bias", None)
        self.init_mu_range = init_mu_range
        self.init_sigma_range = init_sigma_range
        self.align_corners = align_corners
        self.fixed_sigma = fixed_sigma
        self.initialize()

    def initialize(self):
        """
        Initializes the mean and sigma of the Gaussian readout, along with the feature weights.

        mu ~ U(-init_mu_range, init_mu_range). sigma is either exactly init_sigma_range
        (fixed_sigma=True) or ~ U(0, init_sigma_range). Features are set to 1 / channels and
        the bias (if any) to 0.
        """
        self.mu.data.uniform_(-self.init_mu_range, self.init_mu_range)
        if self.fixed_sigma:
            # Same value as uniform_(r, r), written explicitly.
            self.sigma.data.fill_(self.init_sigma_range)
        else:
            self.sigma.data.uniform_(0, self.init_sigma_range)
            warnings.warn(
                "sigma is sampled from uniform distribuiton, instead of a fixed value. Consider setting "
                "fixed_sigma to True")
        self.features.data.fill_(1 / self.in_shape[0])
        if self.bias is not None:
            self.bias.data.fill_(0)

    def sample_grid(self, batch_size, sample=None):
        """
        Returns the grid locations from the core by sampling from a Gaussian distribution.

        Side effect: clamps self.mu to [-1, 1] and self.sigma to >= 0 in place.

        Args:
            batch_size (int): size of the batch
            sample (bool/None): whether to draw a sample from the per-neuron Gaussian
                N(mu, sigma) or to use the mean mu without sampling.
                If None (default), samples during training and uses the mean during evaluation.
                If True/False, overrides the module state (training or eval) accordingly.

        Returns:
            Tensor of shape (batch_size, outdims, 1, 2) with values in [-1, 1].
        """
        with torch.no_grad():
            # At eval time only self.mu is used, so it must belong to [-1, 1].
            self.mu.clamp_(min=-1, max=1)
            # sigma is a standard deviation, so it must be non-negative.
            self.sigma.clamp_(min=0)
        grid_shape = (batch_size, ) + self.grid_shape[1:]

        sample = self.training if sample is None else sample
        if sample:
            norm = self.mu.new(*grid_shape).normal_()
        else:
            # Zero noise keeps a single code path (and device/dtype) for both modes.
            norm = self.mu.new(*grid_shape).zero_()

        # Grid locations sampled around the mean self.mu, kept inside the valid grid range.
        return torch.clamp(norm * self.sigma + self.mu, min=-1, max=1)

    @property
    def grid(self):
        """Mean grid locations (no sampling), shape (1, outdims, 1, 2)."""
        return self.sample_grid(batch_size=1, sample=False)

    def feature_l1(self, average=True):
        """
        Returns the L1 regularization term: either the mean or the sum of the absolute feature weights.

        Args:
            average (bool): if True, use the mean of the absolute weights; otherwise their sum
        """
        if average:
            return self.features.abs().mean()
        else:
            return self.features.abs().sum()

    def forward(self, x, sample=None, shift=None, out_idx=None):
        """
        Propagates the input forward through the readout.

        Args:
            x: input feature map of shape (N, c, w, h); (c, w, h) must equal in_shape
            sample (bool/None): whether to draw a sample from the per-neuron Gaussian
                N(mu, sigma) or to use the mean mu without sampling.
                If None (default), samples during training and uses the mean during evaluation.
                If True/False, overrides the module state (training or eval) accordingly.
            shift (Tensor/None): per-image offset of shape (N, 2) added to the grid locations
                (e.g. from eye-tracking data)
            out_idx (array-like/None): neurons to predict. Either integer indices, or a boolean
                mask of length outdims given as a numpy array or a torch tensor.

        Returns:
            y: neuronal activity of shape (N, number of selected neurons)

        Raises:
            ValueError: if the input feature map shape does not match in_shape.
        """
        N, c, w, h = x.size()
        c_in, w_in, h_in = self.in_shape
        if (c_in, w_in, h_in) != (c, w, h):
            raise ValueError(
                "the specified feature map dimension is not the readout's expected input dimension"
            )
        feat = self.features.view(1, c, self.outdims)
        bias = self.bias
        outdims = self.outdims

        if self.batch_sample:
            # Sample the grid locations separately per image in the batch.
            grid = self.sample_grid(batch_size=N, sample=sample)
        else:
            # Use one sampled set of grid locations for all images in the batch.
            grid = self.sample_grid(batch_size=1, sample=sample).expand(N, outdims, 1, 2)

        if out_idx is not None:
            # Boolean masks must become integer indices; otherwise len(out_idx) below would be
            # the mask length rather than the number of selected neurons.
            if isinstance(out_idx, torch.Tensor):
                if out_idx.dtype == torch.bool:
                    out_idx = out_idx.nonzero(as_tuple=True)[0]
            elif isinstance(out_idx, np.ndarray):
                if out_idx.dtype == bool:
                    out_idx = np.where(out_idx)[0]
            feat = feat[:, :, out_idx]
            grid = grid[:, out_idx]
            if bias is not None:
                bias = bias[out_idx]
            outdims = len(out_idx)

        if shift is not None:
            grid = grid + shift[:, None, None, :]

        # grid_sample output: (N, c, outdims, 1) -> weight channels and sum them out.
        y = F.grid_sample(x, grid, align_corners=self.align_corners)
        y = (y.squeeze(-1) * feat).sum(1).view(N, outdims)

        if self.bias is not None:
            y = y + bias
        return y

    def __repr__(self):
        c, w, h = self.in_shape
        r = self.__class__.__name__ + " (" + "{} x {} x {}".format(
            c, w, h) + " -> " + str(self.outdims) + ")"
        if self.bias is not None:
            r += " with bias"
        for ch in self.children():
            r += " -> " + ch.__repr__() + "\n"
        return r