import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter


class _ConvBnNd(nn.modules.conv._ConvNd):

    _version = 2

    def __init__(
            self,
            # ConvNd args
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            # BatchNormNd args
            # num_features: out_channels
            eps=1e-05,
            momentum=0.1,
            # affine: True
            # track_running_stats: True
            # Args for this module
            freeze_bn=False,
            qconfig=None):
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels,
                                         kernel_size, stride, padding,
                                         dilation, transposed, output_padding,
                                         groups, False, padding_mode)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.frozen = freeze_bn if self.training else True
        self.bn = nn.BatchNorm2d(out_channels, eps, momentum, True, True)

        self.weight_quantizer = qconfig.weight
        self.bias_quantizer = qconfig.bias

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_bn_parameters()

        # This needs to be called after reset_bn_parameters,
        # as they modify the same state.
        if self.training:
            if freeze_bn:
                self.freeze_bn()
            else:
                self.update_bn()
        else:
            self.freeze_bn()

    def reset_running_stats(self):
        self.bn.reset_running_stats()

    def reset_bn_parameters(self):
        self.bn.reset_running_stats()
        init.uniform_(self.bn.weight)
        init.zeros_(self.bn.bias)
        # Note: below is actually for conv, not BN.
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def batch_stats(self, x, bias=None):
        """Get the batch mean and variance of x and update the BatchNorm's
        running mean and variance.

        Args:
            x (torch.Tensor): input batch.
            bias (torch.Tensor): the bias that is to be applied to the batch.

        Returns:
            (mean, variance)

        Note:
            In case of `nn.Linear`, x may be of shape (N, C, L) or (N, L),
            where N is the batch size, C is the number of channels, and L is
            the feature size. The batch norm computes the stats over C in the
            first case or over L in the second case. The batch normalization
            layer is
            [`nn.BatchNorm1d`](https://pytorch.org/docs/stable/nn.html#batchnorm1d).

            In case of `nn.Conv2d`, x is of shape (N, C, H, W), where H and W
            are the image dimensions, and the batch norm computes the stats
            over C. The batch normalization layer is
            [`nn.BatchNorm2d`](https://pytorch.org/docs/stable/nn.html#batchnorm2d).
        """
        channel_size = self.bn.num_features
        self.bn.num_batches_tracked += 1

        # Calculate current batch stats.
        batch_mean = x.transpose(0, 1).contiguous().view(channel_size,
                                                         -1).mean(1)
        # BatchNorm currently uses biased variance (without Bessel's
        # correction), as was discussed at
        # https://github.com/pytorch/pytorch/issues/1410
        #
        # Also see the source code itself:
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L216
        batch_var = x.transpose(0, 1).contiguous().view(channel_size, -1).var(
            1, unbiased=False)

        # Update running stats.
        with torch.no_grad():
            biased_batch_mean = batch_mean + (bias if bias is not None else 0)
            # However - running_var is updated using unbiased variance!
            # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L223
            n = x.numel() / channel_size
            corrected_var = batch_var * (n / float(n - 1))
            momentum = self.bn.momentum
            if momentum is None:
                # momentum is None - we compute a cumulative moving average,
                # as noted in https://pytorch.org/docs/stable/nn.html#batchnorm2d
                momentum = 1. / float(self.bn.num_batches_tracked)
            self.bn.running_mean.mul_(1 - momentum).add_(momentum *
                                                         biased_batch_mean)
            self.bn.running_var.mul_(1 - momentum).add_(momentum *
                                                        corrected_var)

        return batch_mean, batch_var

    def reset_parameters(self):
        super(_ConvBnNd, self).reset_parameters()

    def update_bn(self):
        self.frozen = False
        self.bn.training = True
        return self

    def freeze_bn(self):
        if self.frozen:
            return

        with torch.no_grad():
            # The same implementation as nndct_shared/optimzation/fuse_conv_bn.py
            # is used so that the test accuracy is the same as the deployable
            # model.
            gamma = self.bn.weight.detach().cpu().numpy()
            beta = self.bn.bias.detach().cpu().numpy()
            running_var = self.bn.running_var.detach().cpu().numpy()
            running_mean = self.bn.running_mean.detach().cpu().numpy()
            epsilon = self.bn.eps

            scale = gamma / np.sqrt(running_var + epsilon)
            offset = beta - running_mean * scale

            weight = self.weight.detach().cpu().numpy()
            weight = np.multiply(weight.transpose(1, 2, 3, 0),
                                 scale).transpose(3, 0, 1, 2)
            self.weight.copy_(torch.from_numpy(weight))

            bias = self.bias.detach().cpu().numpy(
            ) if self.bias is not None else 0
            bias = torch.from_numpy(bias * scale + offset)
            if self.bias is not None:
                self.bias.copy_(bias)
            else:
                self.bias = nn.Parameter(bias)

        self.frozen = True
        self.bn.training = False
        return

    def broadcast_correction(self, c: torch.Tensor):
        """Broadcasts a correction factor to the output for elementwise
        operations."""
        expected_output_dim = 4
        view_fillers_dim = expected_output_dim - c.dim() - 1
        view_filler = (1,) * view_fillers_dim
        expected_view_shape = c.shape + view_filler
        return c.view(*expected_view_shape)

    def broadcast_correction_weight(self, c):
        """Broadcasts a correction factor to the weight."""
        if c.dim() != 1:
            raise ValueError(
                "Correction factor needs to have a single dimension")
        expected_weight_dim = 4
        view_fillers_dim = expected_weight_dim - c.dim()
        view_filler = (1,) * view_fillers_dim
        expected_view_shape = c.shape + view_filler
        return c.view(*expected_view_shape)

    def extra_repr(self):
        return super(_ConvBnNd, self).extra_repr()

    def forward(self, x):
        gamma, beta = self.bn.weight, self.bn.bias
        if self.frozen:
            quantized_weight = self.weight_quantizer(self.weight)
            quantized_bias = self.bias_quantizer(self.bias)
            return self._conv_forward(x, quantized_weight, quantized_bias)

        if self.training:
            # The conv is run without bias; the bias is folded into the batch
            # statistics inside batch_stats instead.
            batch_mean, batch_var = self.batch_stats(
                self._conv_forward(x, self.weight, None), self.bias)
            recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
            with torch.no_grad():
                sigma_running = torch.sqrt(self.bn.running_var + self.bn.eps)

            w_corrected = self.weight * self.broadcast_correction_weight(
                gamma / sigma_running)
            w_quantized = self.weight_quantizer(w_corrected)
            recip_c = self.broadcast_correction(sigma_running *
                                                recip_sigma_batch)
            bias_corrected = beta - gamma * batch_mean * recip_sigma_batch
            bias_quantized = self.broadcast_correction(
                self.bias_quantizer(bias_corrected))

            y = self._conv_forward(x, w_quantized, None)
            y.mul_(recip_c).add_(bias_quantized)
        else:
            with torch.no_grad():
                recip_sigma_running = torch.rsqrt(self.bn.running_var +
                                                  self.bn.eps)

            w_corrected = self.weight * self.broadcast_correction_weight(
                gamma * recip_sigma_running)
            w_quantized = self.weight_quantizer(w_corrected)
            corrected_mean = self.bn.running_mean - (
                self.bias if self.bias is not None else 0)
            bias_corrected = beta - gamma * corrected_mean * recip_sigma_running
            bias_quantized = self.bias_quantizer(bias_corrected)

            y = self._conv_forward(x, w_quantized, bias_quantized)
        return y

    def train(self, mode=True):
        """Batchnorm's training behavior is controlled by the self.training
        flag. Prevent changing it if BN is frozen. This makes sure that
        calling `model.train()` on a model with a frozen BN will behave
        properly.
        """
        self.training = mode
        if not self.frozen:
            for module in self.children():
                module.train(mode)
        return self

    # ===== Serialization version history =====
    #
    # Version 1/None
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- gamma : Tensor
    #   |--- beta : Tensor
    #   |--- running_mean : Tensor
    #   |--- running_var : Tensor
    #   |--- num_batches_tracked : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- bn : Module
    #        |--- weight : Tensor (moved from v1.self.gamma)
    #        |--- bias : Tensor (moved from v1.self.beta)
    #        |--- running_mean : Tensor (moved from v1.self.running_mean)
    #        |--- running_var : Tensor (moved from v1.self.running_var)
    #        |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version == 1:
            # BN related parameters and buffers were moved into the BN module
            # for v2.
            v2_to_v1_names = {
                'bn.weight': 'gamma',
                'bn.bias': 'beta',
                'bn.running_mean': 'running_mean',
                'bn.running_var': 'running_var',
                'bn.num_batches_tracked': 'num_batches_tracked',
            }
            for v2_name, v1_name in v2_to_v1_names.items():
                if prefix + v1_name in state_dict:
                    state_dict[prefix + v2_name] = state_dict[prefix + v1_name]
                    state_dict.pop(prefix + v1_name)
                elif strict:
                    missing_keys.append(prefix + v2_name)

        super(_ConvBnNd,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, conv, bn, qconfig):
        """Create a qat module from a float module."""
        assert qconfig, 'Input float module must have a valid qconfig'
        convbn = cls(conv.in_channels, conv.out_channels, conv.kernel_size,
                     conv.stride, conv.padding, conv.dilation, conv.groups,
                     conv.bias is not None, conv.padding_mode, bn.eps,
                     bn.momentum, False, qconfig)
        convbn.weight = conv.weight
        convbn.bias = conv.bias
        convbn.bn.weight = bn.weight
        convbn.bn.bias = bn.bias
        convbn.bn.running_mean = bn.running_mean
        convbn.bn.running_var = bn.running_var
        convbn.bn.num_batches_tracked = bn.num_batches_tracked
        convbn.bn.eps = bn.eps
        return convbn
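# A minimal concrete 2d subclass sketch (an assumption for illustration; the
# surrounding library presumably ships its own equivalent). It shows how
# `from_float` lines up with a subclass __init__: the subclass fixes
# `transposed=False` and `output_padding=0`, which is why `from_float` does
# not pass them, and `_conv_forward` is inherited from nn.Conv2d.
from torch.nn.modules.utils import _pair


class ConvBn2d(_ConvBnNd, nn.Conv2d):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=None,
                 padding_mode='zeros',
                 eps=1e-05,
                 momentum=0.1,
                 freeze_bn=False,
                 qconfig=None):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        _ConvBnNd.__init__(self, in_channels, out_channels, kernel_size,
                           stride, padding, dilation, False, _pair(0), groups,
                           bias, padding_mode, eps, momentum, freeze_bn,
                           qconfig)


# Hypothetical usage, assuming `my_qconfig` exposes callable `weight` and
# `bias` quantizers as required by the assert in __init__:
#   convbn = ConvBn2d.from_float(nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32),
#                                my_qconfig)
#   convbn.freeze_bn()  # fold BN into the conv weights before deployment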
import torch
import torch.nn.functional as fn
from torch import nn
from torch.nn.parameter import Parameter

from .utils import to_pair  # pair-conversion helper from the surrounding package


class Convolution(nn.Module):
    r"""Performs a 2D convolution over an input spike-wave composed of several
    input planes. The current version only supports a stride of 1 with no
    padding.

    The input is a 4D tensor with the size :math:`(T, C_{in}, H_{in}, W_{in})`
    and the corresponding output is of size
    :math:`(T, C_{out}, H_{out}, W_{out})`, where :math:`T` is the number of
    time steps, :math:`C` is the number of feature maps (channels), and
    :math:`H` and :math:`W` are the height and width of the input/output
    planes.

    * :attr:`in_channels` controls the number of input planes
      (channels/feature maps).

    * :attr:`out_channels` controls the number of feature maps in the current
      layer.

    * :attr:`kernel_size` controls the size of the convolution kernel. It can
      be a single integer or a tuple of two integers.

    * :attr:`weight_mean` controls the mean of the normal distribution used
      for initial random weights.

    * :attr:`weight_std` controls the standard deviation of the normal
      distribution used for initial random weights.

    .. note::

        Since this version of convolution does not support padding, it is the
        user's responsibility to add proper padding on the input before
        applying convolution.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        weight_mean (float, optional): Mean of the initial random weights.
            Default: 0.8
        weight_std (float, optional): Standard deviation of the initial random
            weights. Default: 0.02
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 weight_mean=0.8, weight_std=0.02):
        super(Convolution, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = to_pair(kernel_size)
        #self.weight_mean = weight_mean
        #self.weight_std = weight_std

        # For future use
        self.stride = 1
        self.bias = None
        self.dilation = 1
        self.groups = 1
        self.padding = 0

        # Parameters
        self.weight = Parameter(
            torch.Tensor(self.out_channels, self.in_channels,
                         *self.kernel_size))
        self.weight.requires_grad_(False)  # We do not use gradients
        self.reset_weight(weight_mean, weight_std)

    def reset_weight(self, weight_mean=0.8, weight_std=0.02):
        """Resets weights to random values based on a normal distribution.

        Args:
            weight_mean (float, optional): Mean of the random weights.
                Default: 0.8
            weight_std (float, optional): Standard deviation of the random
                weights. Default: 0.02
        """
        self.weight.normal_(weight_mean, weight_std)

    def load_weight(self, target):
        """Loads weights with the target tensor.

        Args:
            target (Tensor): The target tensor.
        """
        self.weight.copy_(target)

    def forward(self, input):
        return fn.conv2d(input, self.weight, self.bias, self.stride,
                         self.padding, self.dilation, self.groups)
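# A minimal usage sketch (illustrative only, not part of the original module):
# the layer consumes a binary spike-wave tensor of shape (T, C_in, H_in, W_in)
# and, since it never pads, any required padding must be applied beforehand.
def _demo_convolution():
    conv = Convolution(in_channels=2, out_channels=4, kernel_size=5)
    spike_wave = (torch.rand(15, 2, 28, 28) > 0.5).float()  # 15 time steps
    potentials = conv(spike_wave)  # shape: (15, 4, 24, 24)
    return potentials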
import torch
from torch.nn import Module, Parameter


class PopArt(Module):
    """PopArt
    http://papers.nips.cc/paper/6076-learning-values-across-many-orders-of-magnitude"""

    def __init__(self,
                 output_layer,
                 beta: float = 0.0003,
                 zero_debias: bool = True,
                 start_pop: int = 8):
        # zero_debias=True and start_pop=8 seem to improve things a little
        # bit, but (False, 0) works as well.
        super().__init__()
        self.start_pop = start_pop
        self.beta = beta
        self.zero_debias = zero_debias
        self.output_layers = output_layer if isinstance(
            output_layer,
            (tuple, list, torch.nn.ModuleList)) else (output_layer,)
        shape = self.output_layers[0].bias.shape
        device = self.output_layers[0].bias.device
        assert all(shape == x.bias.shape for x in self.output_layers)
        self.mean = Parameter(torch.zeros(shape, device=device),
                              requires_grad=False)
        self.mean_square = Parameter(torch.ones(shape, device=device),
                                     requires_grad=False)
        self.std = Parameter(torch.ones(shape, device=device),
                             requires_grad=False)
        self.updates = 0

    @torch.no_grad()
    def update(self, targets):
        beta = max(1 / (self.updates + 1),
                   self.beta) if self.zero_debias else self.beta
        # Note that for beta = 1/self.updates the resulting mean and std would
        # be the true mean and std over all past data.
        new_mean = (1 - beta) * self.mean + beta * targets.mean(0)
        new_mean_square = (1 - beta) * self.mean_square + beta * (
            targets * targets).mean(0)
        new_std = (new_mean_square - new_mean * new_mean).sqrt().clamp(
            0.0001, 1e6)
        assert self.std.shape == (1,), 'this has only been tested in 1D'
        if self.updates >= self.start_pop:
            for layer in self.output_layers:
                # TODO: Properly apply PopArt in RTAC and remove the hack below.
                # We modify the weight while its gradient is being computed,
                # therefore we have to use .data (PyTorch would otherwise
                # throw an error).
                layer.weight.data *= self.std / new_std
                layer.bias.data *= self.std
                layer.bias.data += self.mean - new_mean
                layer.bias.data /= new_std
        self.mean.copy_(new_mean)
        self.mean_square.copy_(new_mean_square)
        self.std.copy_(new_std)
        self.updates += 1
        return self.normalize(targets)

    def normalize(self, x):
        return (x - self.mean) / self.std

    def unnormalize(self, value):
        return value * self.std + self.mean
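# A minimal usage sketch (illustrative; names and sizes are assumptions): wrap
# the final layer of a value head so that targets can be kept normalized while
# the layer output is rescaled to preserve its unnormalized predictions.
def _demo_popart():
    value_head = torch.nn.Linear(64, 1)
    popart = PopArt(value_head)
    targets = torch.randn(32, 1) * 100  # returns on an arbitrary scale
    normalized_targets = popart.update(targets)  # update stats, rescale layer
    # Train value_head against normalized_targets; at evaluation time, map
    # predictions back with popart.unnormalize(value_head(features)).
    return normalized_targets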
import torch
import torch.nn as nn
from torch.nn import Parameter

# `_set_no_grad` is a helper expected from the surrounding package; it
# disables gradient tracking for a module's parameters.


class BaseNeuron(nn.Module):
    r"""Base neuron model, a container to define basic neuron functionalities.

    Defines basic spiking, voltage and trace characteristics. It just has to
    adhere to the API functionalities to integrate within Connection modules.

    Make sure the Neuron class receives input voltage for each neuron and
    returns a Tensor indicating which neurons have spiked.

    :param cells_shape: a list or tuple that specifies the shape of the
        neurons in the conventional PyTorch format, but with the batch size as
        the first dimension.
    :param thresh: spiking threshold, when the cells' voltage surpasses this
        value it generates a spike.
    :param v_rest: voltage resting value, the :class:`Neuron` will default
        back to this over time or after spiking.
    :param dt: duration of a single timestep.
    :param duration_refrac: number of timesteps the :class:`Neuron` is dormant
        after spiking. Make sure ``dt`` fits an integer number of times in
        ``duration_refrac``.
    :param store_trace: ``Boolean`` flag to store the complete spiking
        history, defaults to ``False``.
    """

    def __init__(self,
                 cells_shape,
                 thresh,
                 v_rest,
                 dt,
                 duration_refrac,
                 store_trace=False):
        super(BaseNeuron, self).__init__()
        # Check compatibility of dt and refrac counting
        assert (
            duration_refrac % dt == 0
        ), "dt does not fit an integer amount of times in duration_refrac."
        assert duration_refrac >= 0, "duration_refrac should be non-negative."

        # Fixed parameters
        self.register_buffer("v_rest", torch.tensor(v_rest,
                                                    dtype=torch.float))
        self.register_buffer("dt", torch.tensor(dt, dtype=torch.float))
        self.register_buffer("duration_refrac",
                             torch.tensor(duration_refrac, dtype=torch.float))
        self.register_buffer("thresh_center",
                             torch.tensor(thresh, dtype=torch.float))

        # Define dynamic parameters
        self.register_buffer("spikes",
                             torch.empty(*cells_shape, dtype=torch.bool))
        self.register_buffer("v_cell",
                             torch.empty(*cells_shape, dtype=torch.float))
        self.register_buffer("trace",
                             torch.empty(*cells_shape, dtype=torch.float))
        self.register_buffer("refrac_counts",
                             torch.empty(*cells_shape, dtype=torch.float))

        # Define learnable parameters
        self.thresh = Parameter(torch.empty(*cells_shape, dtype=torch.float),
                                requires_grad=False)

        # In case of storing a complete, local copy of the activity of a neuron
        if store_trace:
            complete_trace = torch.zeros(*cells_shape, 1, dtype=torch.bool)
        else:
            complete_trace = None
        self.register_buffer("complete_trace", complete_trace)

    def spiking(self):
        r"""Return cells that are in spiking state."""
        self.spikes.copy_(self.v_cell >= self.thresh)
        return self.spikes.clone()

    def refrac(self, spikes):
        r"""Basic counting version of cell refractory period.

        Can be overwritten in case of the need of more refined functionality.
        """
        if self.duration_refrac > 0:
            self.refrac_counts[self.refrac_counts > 0] -= self.dt
            self.refrac_counts += self.duration_refrac * self.convert_spikes(
                spikes)
        self.v_cell.masked_fill_(spikes, self.v_rest)

    def concat_trace(self, x):
        r"""Concatenate most recent timestep to the trace storage."""
        self.complete_trace = torch.cat(
            [self.complete_trace, x.unsqueeze(-1)], dim=-1)

    def fold(self, x):
        r"""Fold incoming spike train by summing last dimension."""
        if isinstance(x, (list, tuple)):
            x = torch.cat(x, dim=-1)
        return x.sum(-1)

    def unfold(self, x):
        r"""Move the last dimension (all incoming to a single neuron in the
        current layer) to the first dim.

        This is done because PyTorch broadcasting does not support
        broadcasting over the last dim.
        """
        shape = x.shape
        return x.view(shape[-1], *shape[:-1])

    def convert_spikes(self, spikes):
        r"""Cast ``torch.bool`` spikes to the datatype that is used for
        voltage and weights."""
        return spikes.to(self.v_cell.dtype)

    def reset_state(self):
        r"""Reset cell states that accumulate over time during simulation."""
        self.v_cell.fill_(self.v_rest)
        self.spikes.fill_(False)
        self.refrac_counts.fill_(0)
        self.trace.fill_(0)
        if self.complete_trace is not None:
            self.complete_trace = torch.zeros(
                *self.v_cell.shape, 1, device=self.v_cell.device).bool()

    def reset_thresh(self):
        r"""Reset threshold to initialization values, allows for different
        standard thresholds per neuron."""
        self.thresh.copy_(torch.ones_like(self.thresh) * self.thresh_center)

    def no_grad(self):
        r"""Turn off learning and gradient storing."""
        _set_no_grad(self)

    def init_neuron(self):
        r"""Initialize state, parameters, and turn off gradients."""
        self.no_grad()
        self.reset_state()
        self.reset_thresh()

    def forward(self, x):
        raise NotImplementedError("Neurons must implement `forward`")

    def update_trace(self, x):
        r"""Placeholder for trace update function."""
        raise NotImplementedError("Neurons must implement `update_trace`")

    def update_voltage(self, x):
        r"""Placeholder for voltage update function."""
        raise NotImplementedError("Neurons must implement `update_voltage`")
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# `LogAct` and `RecLogAct` are activation modules expected to be defined
# elsewhere in the surrounding project.


class DMCell(nn.Module):

    def __init__(self,
                 inp_num=5,
                 hid_num=2,
                 Je=8.,
                 Jm=-2,
                 I0=0.0,
                 dt=1.,
                 taus=100.,
                 gamma=0.1,
                 target_mode="x_target",
                 learning_rule="force",
                 activation=LogAct(),
                 rec_activation=RecLogAct()):
        super().__init__()
        self.hid_num = hid_num
        self.inp_num = inp_num
        self.Je = Je
        self.Jm = Jm
        self.I0 = I0
        self.alpha = dt / taus
        self.gamma = gamma
        self.win = Parameter(torch.Tensor(hid_num, inp_num))
        self.wr = Parameter(torch.Tensor(hid_num, hid_num))
        self.act = activation
        self.rec_act = rec_activation
        self.learning_rule = learning_rule
        self.target_mode = target_mode
        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hid_num)
        # stdv = 0.5
        if self.learning_rule == "force":
            self.win.data = torch.zeros((self.hid_num, self.inp_num))
        else:
            # self.win.data.uniform_(-stdv, stdv)
            self.win.data = torch.zeros((self.hid_num, self.inp_num))
        # Recurrent weights: self-excitation Je on the diagonal, mutual
        # inhibition Jm off the diagonal; they are fixed (not trained).
        wr = np.ones((self.hid_num, self.hid_num)) * self.Jm
        wr = wr + np.eye(self.hid_num) * self.Je - np.eye(self.hid_num) * self.Jm
        self.wr.data = torch.FloatTensor(wr)
        self.wr.requires_grad = False

    def apply_win(self, w):
        assert torch.Size(w.shape) == self.win.shape, \
            "w shape should be the same, but got {}".format(w.shape)
        self.win.data = torch.FloatTensor(w)

    def forward(self, x, hid, y=None):
        """learning_rule is "force" or "bp"."""
        if y is None:
            s = hid[0]
            rx = F.linear(x, self.win) + self.I0 + F.linear(s, self.wr)
            r = self.act(rx)
            s_new = s + self.alpha * (-s + (1. - s) * self.gamma * r)
            if self.target_mode == "x_target":
                return rx, (s_new,)
            else:
                return r, (s_new,)
        elif y is not None and self.learning_rule == "force":
            if self.target_mode == "x_target":
                y = y
            else:
                y = self.rec_act(y)
            batch_size = x.shape[0]
            s, P = hid
            rx = F.linear(x, self.win) + self.I0 + F.linear(s, self.wr)
            err = rx - y
            r = x
            # Recursive least squares (FORCE) gain: k = (P r) / (1 + r^T P r).
            # ("fenzi"/"fenmu" are pinyin for numerator/denominator.)
            k_fenmu = F.linear(r, P)
            rPr = torch.sum(k_fenmu * r, 1, True)
            k_fenzi = 1.0 / (1.0 + rPr)
            k = k_fenmu * k_fenzi
            kall = k[:, :, None].repeat(1, 1, self.hid_num)
            dw = -kall * err[:, None, :]
            self.win.copy_(self.win + torch.mean(dw, 0).transpose(1, 0))
            P = P - F.linear(k.t(), k_fenmu.t()) / batch_size
            # r = self.act(rx)
            s_new = s + self.alpha * (-s + (1. - s) * self.gamma * r)
            return err, r, (s_new, P)
        else:
            raise ValueError(
                "No such inference or training configuration in the Decision Network!")
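# A minimal usage sketch (illustrative only): one inference step and one FORCE
# step. The FORCE branch reuses the input `x` as the regressor `r` for both
# the RLS update and the state update, so this sketch keeps inp_num == hid_num
# to make the shapes line up. P is the running inverse correlation matrix of
# the regressors, initialized to the identity.
def _demo_dmcell():
    cell = DMCell(inp_num=2, hid_num=2)
    x = torch.rand(4, 2)  # batch of 4 inputs
    s = torch.zeros(4, 2)  # hidden state

    # Inference: no target given.
    rx, (s_new,) = cell(x, (s,))

    # FORCE step: `win` is updated in place, so keep autograd out of the way.
    P = torch.eye(2)
    y = torch.rand(4, 2)  # target in "x_target" space
    with torch.no_grad():
        err, r, (s_new, P) = cell(x, (s, P), y)
    return err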