# Example #1
# 0
def ann_to_snn(
    ann: Union[nn.Module, str],
    input_shape: Sequence[int],
    data: Optional[torch.Tensor] = None,
    percentile: float = 99.9,
    node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
    **kwargs,
) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a
    ``torch.nn.Module`` into a near-equivalent spiking neural network.

    The ANN's direct children are visited in order (children of a top-level
    ``nn.Sequential`` are flattened one level). Each child is handed to
    ``_ann_to_snn_helper``; when the helper returns a layer and a connection
    they are added to the network, otherwise the child is skipped. The final
    child is converted with ``last=True``.

    :param ann: Artificial neural network implemented in PyTorch. Accepts
        either ``torch.nn.Module`` or path to network saved using
        ``torch.save()``. A module is deep-copied, so the caller's instance
        is not modified by this function.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of
        shape ``[n_examples, ...]``. If ``None``, a ``RuntimeWarning`` is
        issued and the weights are left unscaled.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by
        in data-based normalization scheme.
    :param node_type: Class of ``Nodes`` to use in replacing
        ``torch.nn.Linear`` layers in original ANN.
    :param kwargs: Extra keyword arguments forwarded unchanged to
        ``_ann_to_snn_helper`` for every converted child.
    :return: Spiking neural network implemented in PyTorch.
    :raises IndexError: If the ANN has no child modules.
    """
    if isinstance(ann, str):
        # NOTE: ``torch.load`` unpickles the file, which can execute arbitrary
        # code. Only pass paths to files from a trusted source.
        ann = torch.load(ann)
    else:
        ann = deepcopy(ann)

    assert isinstance(ann, nn.Module)

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.",
                      RuntimeWarning)
    else:
        ann = data_based_normalization(ann=ann,
                                       data=data.detach(),
                                       percentile=percentile)

    snn = Network()

    input_layer = nodes.Input(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    # Collect the modules to convert, flattening one level of nn.Sequential.
    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    # Layers are named after the 1-based position of the child they came from.
    # NOTE: connection sources are named ``str(i - 1)`` purely by position, so
    # the name does not necessarily match an added layer (the input layer is
    # called "Input", and skipped children leave gaps). Kept as-is so existing
    # lookups by connection key keep working.
    prev = input_layer
    for i, current in enumerate(children[:-1], start=1):
        layer, connection = _ann_to_snn_helper(prev, current, node_type,
                                               **kwargs)

        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

    # The final child is converted separately so the helper can be told it is
    # the output layer.
    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev,
                                           current,
                                           node_type,
                                           last=True,
                                           **kwargs)

    i = len(children)

    # Both parts are required (same rule as in the loop above); adding a
    # ``None`` layer or connection to the network would fail.
    if layer is not None and connection is not None:
        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

    return snn
# Example #2
# 0
def Translate_Into_Networks(input_N, Shape, Output_N, Weight):
    """
    Builds one ``Network`` per genome returned by ``Genetic.Read_Gene()``.

    Each gene is a dash-separated string ``"<src>-<symbol>-<dst>-<flag>"``.
    Genes whose flag field is ``'F'`` are ignored. ``<src>`` and ``<dst>`` are
    integer node ids and ``<symbol>`` selects a neuron type:
    ``~`` IFNodes, ``!`` LIFNodes, ``@`` McCullochPitts, ``#`` IzhikevichNodes,
    ``$`` SRM0Nodes.

    Every network gets an ``"Input Layer"``, an ``"Output Layer"`` and one
    single-neuron layer per node id. Each hidden layer is connected from the
    input and to the output, and each active gene adds a ``src -> dst``
    connection. A genome with no active genes gets a single LIF "mid layer"
    between input and output instead.

    Side effects: lists the ``gene/`` directory, imports ``startup`` when it
    holds no ``.txt`` file, and saves every network to ``Network/<index>.pt``.

    :param input_N: Number of input neurons; also the length of the
        input-connection weight vector.
    :param Shape: ``shape`` argument passed to ``nodes.Input``.
    :param Output_N: Number of neurons in the output layer.
    :param Weight: Multiplier applied to the all-ones initial weights.
    :return: List of the constructed networks, in genome order.
    :raises ValueError: If a gene carries an unknown neuron symbol.
    """
    # Gene symbol -> integer neuron-type code.
    symbol_to_code = {'~': 0, '!': 1, '@': 2, '#': 3, '$': 4}
    # Neuron-type code -> node class used for the hidden layers.
    code_to_nodes = {
        0: nodes.IFNodes,
        1: nodes.LIFNodes,
        2: nodes.McCullochPitts,
        3: nodes.IzhikevichNodes,
        4: nodes.SRM0Nodes,
    }

    network_list = []
    path = "gene/"
    if not any(name.endswith(".txt") for name in os.listdir(path)):
        # Imported for its side effects only.
        # NOTE(review): presumably creates the initial gene files -- confirm.
        import startup  # noqa: F401
    Gene_List = Genetic.Read_Gene()

    for i, genes in enumerate(Gene_List):
        network = Network()

        # Decode the active genes into (source id, target id, type code).
        decoded = []
        for gene in genes:
            fields = gene.split('-')
            if fields[3] == 'F':
                continue
            symbol = fields[1]
            if symbol not in symbol_to_code:
                print("THE GENOTYPE VALUE IS UNVALID")
                raise ValueError(
                    f"unknown neuron symbol {symbol!r} in gene {gene!r}")
            decoded.append(
                (int(fields[0]), int(fields[2]), symbol_to_code[symbol]))

        # NOTE(review): an earlier revision ran a de-duplication pass here,
        # guarded by ``if a and b == 1`` where ``a`` and ``b`` were whole gene
        # records (lists). A list never equals 1, so the pass never removed
        # anything. It is omitted rather than guessed at; decide what the
        # intended duplicate rule is before reinstating it.

        # Build the hidden layers. ``layer_list`` is created fresh for every
        # genome so an empty genome can neither hit an unbound name nor reuse
        # node objects that belong to the previous network.
        layer_list = {}
        last = len(decoded) - 1
        for m, (src, dst, _) in enumerate(decoded):
            for n in range(m, len(decoded)):
                later_src, _, later_code = decoded[n]
                if dst == later_src:
                    # A gene at or after position m starts where this gene
                    # ends: the source node takes that gene's neuron type.
                    layer_list[src] = code_to_nodes[later_code](n=1,
                                                                traces=True)
                elif n == last:
                    # The last gene does not continue from this target:
                    # give the target a plain LIF node.
                    layer_list[dst] = nodes.LIFNodes(n=1)

        # Any source node still missing gets the type named by its own gene.
        for src, _, code in decoded:
            if src not in layer_list:
                layer_list[src] = code_to_nodes[code](n=1, traces=True)

        Input_Layer = nodes.Input(n=input_N, shape=Shape, traces=True)
        out = nodes.LIFNodes(n=Output_N, refrac=0, traces=True)
        network.add_layer(layer=Input_Layer, name="Input Layer")
        for key, hidden in layer_list.items():
            network.add_layer(layer=hidden, name=str(key))
        network.add_layer(layer=out, name="Output Layer")

        # NOTE(review): input connections are registered under the source name
        # "Input_Layer" although the layer itself is named "Input Layer". Left
        # unchanged so existing connection-key lookups keep working -- confirm
        # which spelling is intended.
        # NOTE(review): all weights are 1-D tensors; confirm ``Connection``
        # accepts that shape when ``Output_N`` is greater than 1.
        if not layer_list:
            # No usable genes: fall back to input -> one LIF neuron -> output.
            layer = nodes.LIFNodes(n=1, traces=True)
            network.add_layer(layer=layer, name="mid layer")
            inpt_connection = Connection(source=Input_Layer,
                                         target=layer,
                                         w=Weight * torch.ones(input_N))
            opt_connection = Connection(source=layer,
                                        target=out,
                                        w=Weight * torch.ones(1))
            network.add_connection(inpt_connection,
                                   source="Input_Layer",
                                   target="mid layer")
            network.add_connection(opt_connection,
                                   source="mid layer",
                                   target="Output Layer")
        else:
            # Input -> every hidden layer.
            for key, hidden in layer_list.items():
                inpt_connection = Connection(source=Input_Layer,
                                             target=hidden,
                                             w=Weight * torch.ones(input_N))
                network.add_connection(inpt_connection,
                                       source="Input_Layer",
                                       target=str(key))
            # Every hidden layer -> output, trained with MSTDP.
            for key, hidden in layer_list.items():
                output_connection = Connection(source=hidden,
                                               target=out,
                                               w=Weight * torch.ones(1),
                                               update_rule=MSTDP)
                network.add_connection(output_connection,
                                       source=str(key),
                                       target="Output Layer")
            # One hidden -> hidden connection per active gene.
            for src, dst, _ in decoded:
                mid_connection = Connection(source=layer_list[src],
                                            target=layer_list[dst],
                                            w=Weight * torch.ones(1),
                                            update_rule=MSTDP)
                network.add_connection(mid_connection,
                                       source=str(src),
                                       target=str(dst))

        network_list.append(network)
        network.save('Network/' + str(i) + '.pt')
    return network_list
from bindsnet.network import Network
from bindsnet.network import nodes, topology, monitors
from bindsnet.analysis.plotting import plot_spikes


# Parameters.
n_input = 100  # Number of input neurons.
n_output = 100  # Number of output (LIF) neurons.
time = 1000  # Not used in the lines shown here; presumably the simulation length -- confirm.

# Create network object.
network = Network()

# Create input and output groups of neurons.
input_group = nodes.Input(n=n_input)  # 100 input nodes.
output_group = nodes.LIFNodes(n=n_output)  # 100 output nodes (n_output).

network.add_layer(input_group, name='input')
network.add_layer(output_group, name='output')

# Input -> output connection.
# Unit Gaussian feed-forward weights.
w = torch.randn(n_input, n_output)
forward_conn = topology.Connection(input_group, output_group, w=w)

# Output -> output connection.
# Random recurrent weights: Bernoulli 0/1 draws with 1 subtracted on the diagonal.
# NOTE(review): this was labelled "inhibitory", but the off-diagonal entries
# are non-negative -- confirm whether the sign should be flipped.
w = torch.bernoulli(torch.rand(n_output, n_output)) - torch.diag(torch.ones(n_output))
recurrent_conn = topology.Connection(output_group, output_group, w=w)