def _ann_to_snn_helper(prev, current, nxt):
    # language=rst
    """
    Helper function for main ``ann_to_snn`` method.

    Builds the spiking layer and the incoming connection that stand in for one
    PyTorch module. ``nn.Linear`` and ``nn.Conv2d`` become
    ``SubtractiveResetIFNodes`` (``reset=0, thresh=1, refrac=0``) carrying the
    module's weights and bias; ``nn.MaxPool2d`` becomes ``PassThroughNodes``
    behind a ``MaxPool2dConnection``. Any other module yields ``(None, None)``.

    :param prev: Spiking layer that feeds the new layer. It is used as the
        connection ``source``; for convolution / pooling its ``shape`` is
        indexed as 4-D (channels at index 1, spatial sizes at 2 and 3).
    :param current: Current PyTorch module in artificial neural network. It is
        not modified.
    :param nxt: Next PyTorch module; not used by this helper.
    :return: ``(layer, connection)`` for a supported module, otherwise
        ``(None, None)``.
    """

    def _out_size(size, kernel, stride, padding, dilation):
        # Floor-mode output length documented for nn.Conv2d / nn.MaxPool2d.
        # ``dilation`` is included because the connection applies it as well,
        # so the node shape matches what the operator actually produces.
        return int((size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1)

    if isinstance(current, nn.Linear):
        layer = SubtractiveResetIFNodes(
            n=current.out_features, reset=0, thresh=1, refrac=0
        )
        connection = topology.Connection(
            source=prev, target=layer, w=current.weight.t(), b=current.bias
        )
    elif isinstance(current, nn.Conv2d):
        # nn.Conv2d keeps kernel_size / stride / padding / dilation as pairs.
        out_height = _out_size(
            prev.shape[2],
            current.kernel_size[0],
            current.stride[0],
            current.padding[0],
            current.dilation[0],
        )
        out_width = _out_size(
            prev.shape[3],
            current.kernel_size[1],
            current.stride[1],
            current.padding[1],
            current.dilation[1],
        )
        shape = (1, current.out_channels, out_height, out_width)
        layer = SubtractiveResetIFNodes(shape=shape, reset=0, thresh=1, refrac=0)
        connection = topology.Conv2dConnection(
            source=prev,
            target=layer,
            kernel_size=current.kernel_size,
            stride=current.stride,
            padding=current.padding,
            dilation=current.dilation,
            w=current.weight,
            b=current.bias,
        )
    elif isinstance(current, nn.MaxPool2d):
        # Normalise int-or-pair hyperparameters into locals instead of
        # overwriting attributes on the caller's module.
        kernel_size = _pair(current.kernel_size)
        stride = _pair(current.stride)
        padding = _pair(current.padding)
        dilation = _pair(current.dilation)
        out_height = _out_size(
            prev.shape[2], kernel_size[0], stride[0], padding[0], dilation[0]
        )
        out_width = _out_size(
            prev.shape[3], kernel_size[1], stride[1], padding[1], dilation[1]
        )
        layer = PassThroughNodes(shape=(1, prev.shape[1], out_height, out_width))
        connection = topology.MaxPool2dConnection(
            source=prev,
            target=layer,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=current.dilation,
            decay=1,
        )
    else:
        return None, None

    return layer, connection
def _ann_to_snn_helper(prev, current, node_type, **kwargs):
    # language=rst
    """
    Helper function for main ``ann_to_snn`` method.

    Builds the spiking layer and the incoming connection that stand in for one
    PyTorch module:

    * ``nn.Linear`` / ``nn.Conv2d`` -> ``node_type`` neurons fed by a dense /
      convolutional connection carrying the module's weights (a missing bias
      is replaced by zeros);
    * ``nn.MaxPool2d`` -> ``PassThroughNodes`` fed by a ``MaxPool2dConnection``;
    * ``Permute`` -> ``PassThroughNodes`` fed by a ``PermuteConnection``;
    * ``nn.ConstantPad2d`` -> ``PassThroughNodes`` fed by a
      ``ConstantPad2dConnection``.

    :param prev: Spiking layer that feeds the new layer. It is used as the
        connection ``source``; convolution, pooling and padding index its
        ``shape`` as 4-D (channels at index 1, spatial sizes at 2 and 3).
    :param current: Current PyTorch module in artificial neural network. It is
        not modified.
    :param node_type: Neuron class instantiated for ``nn.Linear`` and
        ``nn.Conv2d`` modules (called with ``reset=0, thresh=1, refrac=0``).
    :param kwargs: Extra keyword arguments forwarded to ``node_type``.
    :return: ``(layer, connection)`` for a supported module, otherwise
        ``(None, None)``.
    """

    def _out_size(size, kernel, stride, padding, dilation):
        # Floor-mode output length documented for nn.Conv2d / nn.MaxPool2d.
        # ``dilation`` is included because the connection applies it as well,
        # so the node shape matches what the operator actually produces.
        return int((size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1)

    if isinstance(current, nn.Linear):
        layer = node_type(
            n=current.out_features, reset=0, thresh=1, refrac=0, **kwargs
        )
        bias = current.bias if current.bias is not None else torch.zeros(layer.n)
        connection = topology.Connection(
            source=prev, target=layer, w=current.weight.t(), b=bias
        )
    elif isinstance(current, nn.Conv2d):
        # nn.Conv2d keeps kernel_size / stride / padding / dilation as pairs.
        out_height = _out_size(
            prev.shape[2],
            current.kernel_size[0],
            current.stride[0],
            current.padding[0],
            current.dilation[0],
        )
        out_width = _out_size(
            prev.shape[3],
            current.kernel_size[1],
            current.stride[1],
            current.padding[1],
            current.dilation[1],
        )
        shape = (1, current.out_channels, out_height, out_width)
        layer = node_type(shape=shape, reset=0, thresh=1, refrac=0, **kwargs)
        bias = (
            current.bias
            if current.bias is not None
            else torch.zeros(layer.shape[1])
        )
        connection = topology.Conv2dConnection(
            source=prev,
            target=layer,
            kernel_size=current.kernel_size,
            stride=current.stride,
            padding=current.padding,
            dilation=current.dilation,
            w=current.weight,
            b=bias,
        )
    elif isinstance(current, nn.MaxPool2d):
        # Normalise int-or-pair hyperparameters into locals instead of
        # overwriting attributes on the caller's module.
        kernel_size = _pair(current.kernel_size)
        stride = _pair(current.stride)
        padding = _pair(current.padding)
        dilation = _pair(current.dilation)
        out_height = _out_size(
            prev.shape[2], kernel_size[0], stride[0], padding[0], dilation[0]
        )
        out_width = _out_size(
            prev.shape[3], kernel_size[1], stride[1], padding[1], dilation[1]
        )
        layer = PassThroughNodes(shape=(1, prev.shape[1], out_height, out_width))
        connection = topology.MaxPool2dConnection(
            source=prev,
            target=layer,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=current.dilation,
            decay=1,
        )
    elif isinstance(current, Permute):
        # Output dimension i takes the size of input dimension current.dims[i].
        layer = PassThroughNodes(shape=[prev.shape[d] for d in current.dims])
        connection = PermuteConnection(
            source=prev, target=layer, dims=current.dims
        )
    elif isinstance(current, nn.ConstantPad2d):
        # nn.ConstantPad2d stores padding as (left, right, top, bottom):
        # left/right widen the last dimension, top/bottom the one before it.
        left, right, top, bottom = current.padding
        layer = PassThroughNodes(
            shape=[
                prev.shape[0],
                prev.shape[1],
                prev.shape[2] + top + bottom,
                prev.shape[3] + left + right,
            ]
        )
        connection = ConstantPad2dConnection(
            source=prev, target=layer, padding=current.padding
        )
    else:
        return None, None

    return layer, connection
time = 1000 # Create network object. network = Network() # Create input and output groups of neurons. input_group = nodes.Input(n=n_input) # 100 input nodes. output_group = nodes.LIFNodes(n=n_output) # 500 output nodes. network.add_layer(input_group, name='input') network.add_layer(output_group, name='output') # Input -> output connection. # Unit Gaussian feed-forward weights. w = torch.randn(n_input, n_output) forward_conn = topology.Connection(input_group, output_group, w=w) # Output -> output connection. # Random, inhibitory recurrent weights. w = torch.bernoulli(torch.rand(n_output, n_output)) - torch.diag(torch.ones(n_output)) recurrent_conn = topology.Connection(output_group, output_group, w=w) network.add_connection(forward_conn, source='input', target='output') network.add_connection(recurrent_conn, source='output', target='output') # Monitor input and output spikes during the simulation. for l in network.layers: monitor = monitors.Monitor(network.layers[l], state_vars=['s'], time=time) network.add_monitor(monitor, name=l) # Create input ~ Bernoulli(0.1) for 1,000 timesteps.
def _ann_to_snn_helper(prev, current, scale):
    # language=rst
    """
    Helper function for main ``ann_to_snn`` method.

    Builds the spiking layer and the incoming connection that stand in for one
    PyTorch module:

    * ``nn.Linear`` / ``nn.Conv2d`` -> ``LIFNodes`` fed by a dense /
      convolutional connection whose weights are the module's weights times
      ``scale`` (the module's bias is not transferred);
    * ``Permute`` -> ``PassThroughNodes`` fed by a ``PermuteConnection``;
    * ``nn.ConstantPad2d`` -> ``PassThroughNodes`` fed by a
      ``ConstantPad2dConnection``.

    :param prev: Spiking layer that feeds the new layer. It is used as the
        connection ``source``; convolution and padding index its ``shape`` as
        4-D (channels at index 1, spatial sizes at 2 and 3).
    :param current: Current PyTorch module in artificial neural network.
    :param scale: Factor multiplied into the copied ``nn.Linear`` /
        ``nn.Conv2d`` weights.
    :return: ``(layer, connection)`` for a supported module, otherwise
        ``(None, None)``.
    """
    # Neuron parameters shared by the Linear and Conv2d replacements.
    lif_params = dict(refrac=0, traces=True, thresh=-52, rest=-65.0, decay=1e-2)

    if isinstance(current, nn.Linear):
        layer = LIFNodes(n=current.out_features, **lif_params)
        connection = topology.Connection(
            source=prev, target=layer, w=current.weight.t() * scale
        )
    elif isinstance(current, nn.Conv2d):
        # Floor-mode output length documented for nn.Conv2d; ``dilation`` is
        # included because the connection applies it as well.
        out_height, out_width = (
            int((size + 2 * pad - dil * (k - 1) - 1) // stride + 1)
            for size, k, stride, pad, dil in zip(
                (prev.shape[2], prev.shape[3]),
                current.kernel_size,
                current.stride,
                current.padding,
                current.dilation,
            )
        )
        shape = (1, current.out_channels, out_height, out_width)
        layer = LIFNodes(shape=shape, **lif_params)
        connection = topology.Conv2dConnection(
            source=prev,
            target=layer,
            kernel_size=current.kernel_size,
            stride=current.stride,
            padding=current.padding,
            dilation=current.dilation,
            w=current.weight * scale,
        )
    elif isinstance(current, Permute):
        # Output dimension i takes the size of input dimension current.dims[i].
        layer = PassThroughNodes(shape=[prev.shape[d] for d in current.dims])
        connection = PermuteConnection(
            source=prev, target=layer, dims=current.dims
        )
    elif isinstance(current, nn.ConstantPad2d):
        # nn.ConstantPad2d stores padding as (left, right, top, bottom):
        # left/right widen the last dimension, top/bottom the one before it.
        left, right, top, bottom = current.padding
        layer = PassThroughNodes(
            shape=[
                prev.shape[0],
                prev.shape[1],
                prev.shape[2] + top + bottom,
                prev.shape[3] + left + right,
            ]
        )
        connection = ConstantPad2dConnection(
            source=prev, target=layer, padding=current.padding
        )
    else:
        return None, None

    return layer, connection