def ann_to_snn(
    ann: Union[nn.Module, str],
    input_shape: Sequence[int],
    data: Optional[torch.Tensor] = None,
    percentile: float = 99.9,
    node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
    **kwargs,
) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a
    ``torch.nn.Module`` into a near-equivalent spiking neural network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts
        either ``torch.nn.Module`` or path to network saved using
        ``torch.save()``. A module passed in directly is deep-copied before
        any normalization, so the caller's object is not modified.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of
        shape ``[n_examples, ...]``. If ``None``, a ``RuntimeWarning`` is
        emitted and weights are not scaled.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale
        by in data-based normalization scheme.
    :param node_type: Class of ``Nodes`` to use in replacing
        ``torch.nn.Linear`` layers in original ANN.
    :param kwargs: Forwarded unchanged to ``_ann_to_snn_helper``.
    :return: Spiking neural network implemented in PyTorch. The input layer
        is named ``"Input"``; the module at 1-based position ``k`` of the
        flattened child list is added as layer ``str(k)``. Each connection is
        registered under the names of the two layers it actually joins.
    :raises ValueError: If ``ann`` has no child modules to convert.
    """
    if isinstance(ann, str):
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        ann = torch.load(ann)
    else:
        ann = deepcopy(ann)

    assert isinstance(ann, nn.Module)

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.", RuntimeWarning)
    else:
        ann = data_based_normalization(
            ann=ann, data=data.detach(), percentile=percentile
        )

    # Flatten exactly one level of ``nn.Sequential`` so its members are
    # converted individually; any deeper containers are passed through as-is.
    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    if not children:
        raise ValueError("ann has no child modules to convert.")

    snn = Network()

    input_layer = nodes.Input(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    # Remember the name of the most recently added layer. Positions for which
    # the helper returns ``None`` are skipped, so the previous layer's name is
    # not necessarily ``str(position - 1)`` (and is ``"Input"`` at the start).
    prev, prev_name = input_layer, "Input"
    for position, current in enumerate(children[:-1], start=1):
        layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)
        if layer is None or connection is None:
            continue

        name = str(position)
        snn.add_layer(layer, name=name)
        snn.add_connection(connection, source=prev_name, target=name)
        prev, prev_name = layer, name

    # The final module is converted with ``last=True``.
    layer, connection = _ann_to_snn_helper(
        prev, children[-1], node_type, last=True, **kwargs
    )
    # Same guard as in the loop: only wire the layer when both parts exist.
    if layer is not None and connection is not None:
        name = str(len(children))
        snn.add_layer(layer, name=name)
        snn.add_connection(connection, source=prev_name, target=name)

    return snn
# Maps the symbol in field 1 of a gene string to the integer neuron-type code
# understood by ``_make_neuron``.
_GENE_SYMBOL_CODES = {'~': 0, '!': 1, '@': 2, '#': 3, '$': 4}


def _decode_genome(genome):
    """Decode one genome into a list of ``[src, dst, code]`` edge triples.

    Each gene is a ``'-'``-separated string whose fields are read as: field 0
    the integer source neuron id, field 1 a neuron-type symbol, field 2 the
    integer destination neuron id, and field 3 a flag. Genes whose flag is
    ``'F'`` are skipped.

    :raises ValueError: If the symbol is not one of ``_GENE_SYMBOL_CODES``.
    """
    edges = []
    for gene in genome:
        fields = gene.split('-')
        if fields[3] == 'F':
            continue
        code = _GENE_SYMBOL_CODES.get(fields[1])
        if code is None:
            print("THE GENOTYPE VALUE IS UNVALID")
            raise ValueError(
                f"invalid genotype symbol {fields[1]!r} in gene {gene!r}")
        edges.append([int(fields[0]), int(fields[2]), code])
    return edges


def _make_neuron(code):
    """Return a new one-neuron group with traces for ``code`` 0-4, else ``None``."""
    # Class names are resolved lazily so only the requested type is looked up.
    type_names = {
        0: "IFNodes",
        1: "LIFNodes",
        2: "McCullochPitts",
        3: "IzhikevichNodes",
        4: "SRM0Nodes",
    }
    type_name = type_names.get(code)
    if type_name is None:
        return None
    return getattr(nodes, type_name)(n=1, traces=True)


def _build_hidden_layers(edges):
    """Create the hidden neuron groups for ``edges``, keyed by neuron id.

    Pass 1: for every edge ``m`` and every edge ``n >= m`` whose source equals
    ``m``'s destination, ``m``'s source gets the neuron type coded by edge
    ``n``. When ``n`` is the last edge and its source is not ``m``'s
    destination, ``m``'s destination gets a plain ``LIFNodes(n=1)``.
    Pass 2: any edge source still without a neuron gets the type of its own
    code. Later assignments overwrite earlier ones. Dict insertion order
    determines the order in which the layers are added to the network.

    :raises ValueError: If pass 1 meets a code ``_make_neuron`` rejects.
    """
    layers = {}
    last = len(edges) - 1
    for m, (src, dst, _) in enumerate(edges):
        for n in range(m, len(edges)):
            next_src, _, next_code = edges[n]
            if dst == next_src:
                neuron = _make_neuron(next_code)
                if neuron is None:
                    print("UNVALID GENO_NEURON CODE")
                    raise ValueError(f"invalid neuron code {next_code!r}")
                layers[src] = neuron
            elif n == last:
                layers[dst] = nodes.LIFNodes(n=1)
    for src, _, code in edges:
        if src not in layers:
            neuron = _make_neuron(code)
            if neuron is not None:
                layers[src] = neuron
    return layers


def Translate_Into_Networks(input_N, Shape, Output_N, Weight):
    """Build and save one ``Network`` per genome from ``Genetic.Read_Gene()``.

    If ``gene/`` holds no ``.txt`` files, the ``startup`` module is imported
    first. Each genome is decoded into edges (``_decode_genome``), and hidden
    neurons are created for them (``_build_hidden_layers``). The network is
    then wired as follows: the input layer connects to every hidden neuron,
    every hidden neuron connects to the output layer, and each decoded edge
    adds one hidden-to-hidden connection. If a genome yields no hidden
    neurons, a single LIF ``"mid layer"`` bridges input and output instead.
    Network ``k`` is saved to ``Network/<k>.pt``.

    Note: the original gene-pruning pass was removed. Its guard ``a and b == 1``
    compared a list with ``1``, so it never executed.

    :param input_N: Number of input neurons.
    :param Shape: Shape passed to the input layer.
    :param Output_N: Number of output (LIF) neurons.
    :param Weight: Multiplier applied to the all-ones initial weights of every
        connection.
    :return: The constructed networks, in genome order.
    :raises ValueError: On an unknown genotype symbol in a gene string.
    """
    # NOTE(review): os.listdir raises if ``gene/`` does not exist.
    gene_files = [name for name in os.listdir("gene/") if name.endswith(".txt")]
    if not gene_files:
        # Imported for its side effects. NOTE(review): Python caches modules,
        # so this only takes effect once per process — confirm intended.
        import startup  # noqa: F401

    network_list = []
    for index, genome in enumerate(Genetic.Read_Gene()):
        network = Network()
        edges = _decode_genome(genome)
        hidden = _build_hidden_layers(edges)

        input_layer = nodes.Input(n=input_N, shape=Shape, traces=True)
        output_layer = nodes.LIFNodes(n=Output_N, refrac=0, traces=True)
        network.add_layer(layer=input_layer, name="Input Layer")
        for key, layer in hidden.items():
            network.add_layer(layer=layer, name=str(key))
        network.add_layer(layer=output_layer, name="Output Layer")

        # NOTE(review): these 1-D weights (e.g. ``torch.ones(1)`` feeding
        # ``Output_N`` neurons) look inconsistent with BindsNET's usual
        # ``(source.n, target.n)`` weight shape when Output_N > 1 — confirm.
        if not hidden:
            mid_layer = nodes.LIFNodes(n=1, traces=True)
            network.add_layer(layer=mid_layer, name="mid layer")
            network.add_connection(
                Connection(source=input_layer, target=mid_layer,
                           w=Weight * torch.ones(input_N)),
                source="Input Layer", target="mid layer")
            network.add_connection(
                Connection(source=mid_layer, target=output_layer,
                           w=Weight * torch.ones(1)),
                source="mid layer", target="Output Layer")
        else:
            for key, layer in hidden.items():
                network.add_connection(
                    Connection(source=input_layer, target=layer,
                               w=Weight * torch.ones(input_N)),
                    source="Input Layer", target=str(key))
            for key, layer in hidden.items():
                network.add_connection(
                    Connection(source=layer, target=output_layer,
                               w=Weight * torch.ones(1), update_rule=MSTDP),
                    source=str(key), target="Output Layer")
            for src, dst, _ in edges:
                network.add_connection(
                    Connection(source=hidden[src], target=hidden[dst],
                               w=Weight * torch.ones(1), update_rule=MSTDP),
                    source=str(src), target=str(dst))

        network_list.append(network)
        network.save('Network/' + str(index) + '.pt')
    return network_list
import torch

from bindsnet.network import Network
from bindsnet.network import nodes, topology, monitors
from bindsnet.analysis.plotting import plot_spikes

# Parameters.
n_input = 100
n_output = 100
time = 1000

# Create network object.
network = Network()

# Create input and output groups of neurons.
input_group = nodes.Input(n=n_input)  # n_input (100) input nodes.
output_group = nodes.LIFNodes(n=n_output)  # n_output (100) output nodes.
network.add_layer(input_group, name='input')
network.add_layer(output_group, name='output')

# Input -> output connection.
# Unit Gaussian feed-forward weights.
w = torch.randn(n_input, n_output)
forward_conn = topology.Connection(input_group, output_group, w=w)

# Output -> output connection.
# Random binary recurrent weights: entry (i, j) is 1 with probability drawn
# from U(0, 1), else 0; subtracting the identity makes each self-weight 0 or -1.
# NOTE(review): off-diagonal weights are non-negative (excitatory), not
# inhibitory as previously labelled; negate ``w`` if inhibition was intended.
w = torch.bernoulli(torch.rand(n_output, n_output)) - torch.diag(torch.ones(n_output))
recurrent_conn = topology.Connection(output_group, output_group, w=w)