def LM_model(
    plot_parameters=False,
    plot_results=False,
    arguments=False,
    info_PN=False,
    figures=None,
    A=1.0,
    BA=0.5,
    PN_KC_weight=0.25,
    min_weight=0.0001,
    PN_thresh=-40.0,
    KC_thresh=-25.0,
    EN_thresh=-40.0,
    modification=5.0,  # best results with 1.0 for CurrentLIF (0.1 for LIF); increased to encode the richness of the input
    stimulation_time=40  # milliseconds 
):

    begin_time = datetime.datetime.now()

    ### parameters

    dt = 1.0
    learning_time = 50  # milliseconds
    test_time = 50  # milliseconds

    if arguments:
        if len(sys.argv) == 1:
            print(
                "Warning - usage : python LM_model.py [name of learned image file] [name(s) of test image file(s)]"
            )
            sys.exit()

        try:
            # If the last argument is an integer, it scales A and min_weight;
            # every argument before it is an image file.
            A = int(sys.argv[-1]) * 0.1
            min_weight = int(sys.argv[-1]) * 0.0001
            last_file_index = len(sys.argv) - 1
        except ValueError:
            last_file_index = len(sys.argv)
        list_files = sys.argv[1:last_file_index]

    else:
        list_files = figures

    ### get image data

    print("Upload image data")

    input_data = {"Learning": None, "Test": {}}
    for i in range(len(list_files)):
        with open(list_files[i], "r") as file_image:
            image = [[float(x) for x in line.split()] for line in file_image]
        image = np.array(image).reshape(1, 10, 36)

        # Present the image for the first `stimulation_time` ms, then silence.
        if i == 0:
            print(list_files[i], "=> learning")
            input_data["Learning"] = {
                "Input":
                torch.from_numpy(modification * np.array([
                    image if step <= stimulation_time else np.zeros((1, 10, 36))
                    for step in range(int(learning_time / dt))
                ]))
            }
        else:
            print(list_files[i], "=> test")
            input_data["Test"][list_files[i]] = {
                "Input":
                torch.from_numpy(modification * np.array([
                    image if step <= stimulation_time else np.zeros((1, 10, 36))
                    for step in range(int(test_time / dt))
                ]))
            }

    ### network initialization based on Ardin et al's article

    print("Initialize network")

    landmark_guidance = Network(dt=dt)

    # layers
    input_layer = Input(n=360, shape=(10, 36))
    PN = Izhikevich(n=360,
                    traces=True,
                    tc_decay=10.0,
                    thresh=PN_thresh,
                    rest=-60.0,
                    C=100,
                    a=0.3,
                    b=-0.2,
                    c=-65,
                    d=8,
                    k=2)
    KC = Izhikevich(n=20000,
                    traces=True,
                    tc_decay=10.0,
                    thresh=KC_thresh,
                    rest=-85.0,
                    C=4,
                    a=0.01,
                    b=-0.3,
                    c=-65,
                    d=8,
                    k=0.035)
    EN = Izhikevich(n=1,
                    traces=True,
                    tc_decay=10.0,
                    thresh=EN_thresh,
                    rest=-60.0,
                    C=100,
                    a=0.3,
                    b=-0.2,
                    c=-65,
                    d=8,
                    k=2)
    landmark_guidance.add_layer(layer=input_layer, name="Input")
    landmark_guidance.add_layer(layer=PN, name="PN")
    landmark_guidance.add_layer(layer=KC, name="KC")
    landmark_guidance.add_layer(layer=EN, name="EN")

    # connections
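    # One-to-one Input -> PN mapping: the scatter below builds an identity
    # matrix (equivalently torch.eye(PN.n), since input_layer.n == PN.n here).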
    connection_weight = torch.zeros(input_layer.n, PN.n).scatter_(
        1, torch.tensor([[i, i] for i in range(PN.n)]), 1.)
    input_PN = Connection(source=input_layer, target=PN, w=connection_weight)

    # Each KC samples 10 distinct PNs uniformly at random (sparse, divergent
    # PN -> KC connectivity).
    connection_weight = torch.zeros(PN.n, KC.n).t()
    connection_weight = connection_weight.scatter_(
        1,
        torch.tensor([
            np.random.choice(PN.n, size=10, replace=False) for i in range(KC.n)
        ]).long(), PN_KC_weight)
    PN_KC = AllToAllConnection(source=PN,
                               target=KC,
                               w=connection_weight.t(),
                               tc_synaptic=3.0,
                               phi=0.93)

    KC_EN = AllToAllConnection(source=KC,
                               target=EN,
                               w=torch.ones(KC.n, EN.n) * 2.0,
                               tc_synaptic=8.0,
                               phi=8.0)
    print()
    print(KC_EN.w)
    print()
    landmark_guidance.add_connection(connection=input_PN,
                                     source="Input",
                                     target="PN")
    landmark_guidance.add_connection(connection=PN_KC,
                                     source="PN",
                                     target="KC")
    landmark_guidance.add_connection(connection=KC_EN,
                                     source="KC",
                                     target="EN")

    # learning rule
    KC_EN.update_rule = STDP(connection=KC_EN,
                             nu=(-A, -A),
                             tc_eligibility_trace=40.0,
                             tc_plus=15,
                             tc_minus=15,
                             tc_reward=20.0,
                             min_weight=min_weight)
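    # Reward-modulated STDP: with nu=(-A, -A) both pre-post and post-pre events
    # depress the KC->EN synapses, so a learned view drives fewer EN spikes;
    # the update is held in an eligibility trace (tc_eligibility_trace) and
    # gated by the reward signal (reward=BA, decaying with tc_reward).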

    # monitors
    input_monitor = Monitor(obj=input_layer, state_vars=("s",))
    PN_monitor = Monitor(obj=PN, state_vars=("s", "v"))
    KC_monitor = Monitor(obj=KC, state_vars=("s", "v"))
    EN_monitor = Monitor(obj=EN, state_vars=("s", "v"))
    landmark_guidance.add_monitor(monitor=input_monitor, name="Input monitor")
    landmark_guidance.add_monitor(monitor=PN_monitor, name="PN monitor")
    landmark_guidance.add_monitor(monitor=KC_monitor, name="KC monitor")
    landmark_guidance.add_monitor(monitor=EN_monitor, name="EN monitor")
    print(datetime.datetime.now() - begin_time)

    ### run network : learning of 1 view
    begin_time = datetime.datetime.now()

    print("Run - learning view")

    landmark_guidance.learning = True
    landmark_guidance.run(inputs=input_data["Learning"],
                          time=learning_time,
                          reward=BA,
                          n_timesteps=learning_time / dt)
    landmark_guidance.learning = False

    print()
    print(KC_EN.w)
    print()

    print("> View learned")

    if plot_parameters:
        plt.figure()
        plt.plot(range(learning_time + 1),
                 torch.tensor(KC_EN.update_rule.cumul_weigth))
        plt.title("Evolution of KC_EN weights for A=" + str(A) +
                  " and thresh=" + str(min_weight))
        # plt.savefig("./manual_tuning/weights_nu"+str(A)+"_thresh"+str(min_weight)+".png")

        plt.figure()
        plt.plot(range(learning_time + 1),
                 torch.tensor(KC_EN.update_rule.cumul_et))
        plt.title("Evolution of KC_EN eligibility traces for A=" + str(A) +
                  " and thresh=" + str(min_weight))
        # plt.savefig("./manual_tuning/eligibility_nu"+str(A)+"_thresh"+str(min_weight)+".png")

        plt.figure()
        plt.plot(range(learning_time),
                 torch.tensor(KC_EN.update_rule.cumul_delta_t), "b",
                 range(learning_time),
                 torch.tensor(KC_EN.update_rule.cumul_KC), "r",
                 range(learning_time),
                 torch.tensor(KC_EN.update_rule.cumul_EN), "g")
        # plt.plot(range(learning_time), torch.tensor(KC_EN.update_rule.cumul_delta_t))
        plt.title("Evolution of delta_t")

        plt.figure()
        plt.plot(range(learning_time),
                 torch.tensor(KC_EN.update_rule.cumul_STDP))
        plt.title("Evolution of STDP")

        plt.figure()
        plt.plot(range(learning_time),
                 torch.tensor(KC_EN.update_rule.cumul_pre_post))
        plt.title("Evolution of pre_post_spikes")

        plt.figure()
        plt.plot(range(learning_time), torch.tensor(PN_KC.cumul_I))
        plt.title("Evolution of I KC")

        plt.figure()
        plt.plot(range(learning_time), torch.tensor(KC_EN.cumul_I))
        plt.xlim(left=0, right=learning_time)
        plt.title("Evolution of I EN")

        plt.show(block=False)

    ### run network : test on one or more views

    print("Run - test of one or more views")
    view = {"name": None, "mean_EN": None}
    nb_spikes = []

    plt.ioff()
    for (name, data) in input_data["Test"].items():
        landmark_guidance.reset_state_variables()
        landmark_guidance.run(inputs=data,
                              time=test_time,
                              n_timesteps=test_time / dt)

        spikes = {
            "PN": PN_monitor.get("s")[-test_time:],
            "KC": KC_monitor.get("s")[-test_time:],
            "EN": EN_monitor.get("s")[-test_time:]
        }
        voltages = {
            "PN": PN_monitor.get("v")[-test_time:],
            "KC": KC_monitor.get("v")[-test_time:],
            "EN": EN_monitor.get("v")[-test_time:]
        }

        if info_PN:
            frequences = []
            for nodes in spikes["PN"].squeeze().t():
                frequences.append(len(torch.nonzero(nodes)))
            frequences = torch.tensor(frequences).float()
            print("Mean spikes PN :", torch.mean(frequences), "- Max :",
                  torch.max(frequences), "- Min :", torch.min(frequences))

        print(name, ":  nb spikes EN =", len(torch.nonzero(spikes["EN"])))
        nb_spikes.append(len(torch.nonzero(spikes["EN"])))

        if view["mean_EN"] == None or len(torch.nonzero(
                spikes["EN"])) < view["mean_EN"]:
            view["mean_EN"] = len(torch.nonzero(spikes["EN"]))
            view["name"] = name

        if plot_results:
            Pspikes = plot_spikes(spikes)
            for subplot in Pspikes[1]:
                subplot.set_xlim(left=0, right=test_time)
            Pspikes[1][1].set_ylim(bottom=0, top=KC.n)
            plt.suptitle("Results for " + name)

            # Pvoltages = plot_voltages(voltages, plot_type="line")
            # for v_subplot in Pvoltages[1]:
            #     v_subplot.set_xlim(left=0, right=test_time)
            # Pvoltages[1][2].set_ylim(bottom=min(-70, min(voltages["EN"])), top=max(-50, max(voltages["EN"])))
            # plt.suptitle("Results for " + name)

            plt.show(block=False)

    print("Most familiar view:", view["name"])

    plt.show(block=True)
    print(datetime.datetime.now() - begin_time)

    # True when every test view elicited the same number of EN spikes
    # (i.e. the network could not discriminate between them).
    if len(set(nb_spikes)) == 1:
        return (view["name"], True)
    else:
        return (view["name"], False)
Example #2
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(6400, 1000)
        self.fc2 = nn.Linear(1000, 4)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# load ANN
dqn_network = torch.load("trained_shallow_ANN.pt", map_location=device)

# Build Spiking network.
network = Network(dt=dt).to(device)

# Layers of neurons.
inpt = Input(n=6400, traces=False)  # Input layer
middle = LIFNodes(n=1000, refrac=0, traces=True, thresh=-52.0,
                  rest=-65.0)  # Hidden layer
readout = LIFNodes(n=4, refrac=0, traces=True, thresh=-52.0,
                   rest=-65.0)  # Readout layer
layers = {"X": inpt, "M": middle, "R": readout}

# Set the connections between layers with the values set by the ANN
# Input -> hidden.
inpt_middle = Connection(
    source=layers["X"],
    target=layers["M"],
    w=torch.transpose(dqn_network.fc1.weight, 0, 1) * layer1scale,
Example #3
def ann_to_snn(
    ann: Union[nn.Module, str],
    input_shape: Sequence[int],
    data: Optional[torch.Tensor] = None,
    percentile: float = 99.9,
    node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
    **kwargs,
) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural
    network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts either ``torch.nn.Module`` or path to network
                saved using ``torch.save()``.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of shape ``[n_examples, ...]``.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by in data-based normalization scheme.
    :param node_type: Class of ``Nodes`` to use in replacing ``torch.nn.Linear`` layers in original ANN.
    :return: Spiking neural network implemented in PyTorch.
    """
    if isinstance(ann, str):
        ann = torch.load(ann)
    else:
        ann = deepcopy(ann)

    assert isinstance(ann, nn.Module)

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.",
                      RuntimeWarning)
    else:
        ann = data_based_normalization(ann=ann,
                                       data=data.detach(),
                                       percentile=percentile)

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    children = []
    for c in ann.children():
        if isinstance(c, nn.Sequential):
            for c2 in list(c.children()):
                children.append(c2)
        else:
            children.append(c)

    i = 0
    prev = input_layer
    while i < len(children) - 1:
        current, nxt = children[i:i + 2]
        layer, connection = _ann_to_snn_helper(prev, current, node_type,
                                               **kwargs)

        i += 1

        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)

    i += 1

    if layer is not None or connection is not None:
        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

    return snn
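
# Usage sketch (hypothetical model and data; `node_type` selects the spiking
# neuron class substituted for the ANN's layers):
#
#     snn = ann_to_snn(my_ann, input_shape=(1, 28, 28), data=train_images,
#                      percentile=99.9, node_type=SubtractiveResetIFNodes)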
Example #4

def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250, time=50, lr=1e-2, lr_decay=0.99, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, train=True, plot=False,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if train:
        n_examples = n_train
    else:
        n_examples = n_test

    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=0, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w, update_rule=PostPre,
            nu=[0, lr], wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=network.layers['Y'], target=network.layers['Y'], w=w, wmin=-inhib, wmax=0
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

    else:
        path = os.path.join('..', '..', 'params', data, model)
        network = load_network(os.path.join(path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=os.path.join('..', '..', 'data', 'FashionMNIST'), download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images = images / 255

    # if train:
    #     for i in range(n_neurons):
    #         network.connections['X', 'Y'].w[:, i] = images[i] + images[i].mean() * torch.randn(784)

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join('..', '..', 'params', data, model)
        path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}

    if train:
        best_accuracy = 0

    spikes = {}

    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, predictions = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    path = os.path.join('..', '..', 'params', data, model)
                    if not os.path.isdir(path):
                        os.makedirs(path)

                    network.save(os.path.join(path, model_name + '.pt'))
                    path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores), open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % n_examples].repeat([time, 1])
        inpts = {'X': image}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            inpts = {'X': image}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _input = images[i % n_examples].view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im, wmax=0.25)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, predictions = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
        proportions=proportions, ngram_scores=ngram_scores, n=2
    )
    print_results(curves)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            path = os.path.join('..', '..', 'params', data, model)
            if not os.path.isdir(path):
                os.makedirs(path)

            network.save(os.path.join(path, model_name + '.pt'))
            path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    path = os.path.join('..', '..', 'curves', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    if train:
        to_write = ['train'] + params
    else:
        to_write = ['test'] + params

    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'

    torch.save((curves, update_interval, n_examples), open(os.path.join(path, f), 'wb'))

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    if train:
        to_write = params + results
    else:
        to_write = test_params + results

    to_write = [str(x) for x in to_write]

    if train:
        name = 'train.csv'
    else:
        name = 'test.csv'

    if not os.path.isfile(os.path.join(path, name)):
        with open(os.path.join(path, name), 'w') as f:
            if train:
                f.write('random_seed,n_neurons,n_train,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                        'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                        'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')
            else:
                f.write('random_seed,n_neurons,n_train,n_test,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                        'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                        'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')

    with open(os.path.join(path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')
Example #5
plot = args.plot
gpu = args.gpu
device_id = args.device_id

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Set up GPU use.
if gpu and torch.cuda.is_available():
    torch.cuda.set_device(device_id)
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')

network = Network(dt=dt)
inpt = Input(784, shape=(1, 28, 28))
network.add_layer(inpt, name="I")
output = LIFNodes(n_neurons,
                  thresh=-52 + np.random.randn(n_neurons).astype(float))
network.add_layer(output, name="O")
C1 = Connection(source=inpt,
                target=output,
                w=0.5 * torch.randn(inpt.n, output.n))
C2 = Connection(source=output,
                target=output,
                w=0.5 * torch.randn(output.n, output.n))

network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")
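
# Minimal run sketch (hypothetical random input; the keyword may be `inpts`
# in the older BindsNET API used by some of the other examples here):
#
#     network.run(inputs={"I": torch.bernoulli(0.1 * torch.ones(250, 1, 28, 28))},
#                 time=250)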
Example #6
def ann_to_snn(ann: Union[nn.Module, str], input_shape: Sequence[int], data: Optional[torch.Tensor] = None,
               percentile: float = 99.9) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural
    network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts either ``torch.nn.Module`` or path to network
                saved using ``torch.save()``.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of shape ``[n_examples, ...]``.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by in data-based normalization scheme.
    :return: Spiking neural network implemented in PyTorch.
    """
    if isinstance(ann, str):
        ann = torch.load(ann)

    assert isinstance(ann, nn.Module)

    if data is not None:
        print()
        print('Example data provided. Performing data-based normalization...')

        t0 = t()
        ann = data_based_normalization(
            ann=ann, data=data.detach(), percentile=percentile
        )

        print(f'Elapsed: {t() - t0:.4f}')

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name='Input')

    children = []
    for c in ann.children():
        if isinstance(c, nn.Sequential):
            for c2 in list(c.children()):
                children.append(c2)
        else:
            children.append(c)

    i = 0
    prev = input_layer
    while i < len(children) - 1:
        current, nxt = children[i:i + 2]
        layer, connection = _ann_to_snn_helper(prev, current, nxt)

        i += 1

        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev, current, None)

    i += 1

    if layer is not None or connection is not None:
        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

    return snn
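
# Usage sketch (hypothetical tensors): passing `data` enables data-based
# weight normalization, rescaling each layer by the given percentile of its
# activations on the sample batch:
#
#     snn = ann_to_snn(my_ann, input_shape=(1, 28, 28),
#                      data=train_images[:100], percentile=99.9)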
Example #7
def main(seed=0,
         n_train=60000,
         n_test=10000,
         time=50,
         lr=0.01,
         lr_decay=0.95,
         update_interval=500,
         max_prob=1.0,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [seed, n_train, time, lr, lr_decay, update_interval, max_prob]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, time, lr, lr_decay, update_interval,
            max_prob
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Loss function on output firing rates.
    criterion = torch.nn.CrossEntropyLoss()
    n_examples = n_train if train else n_test

    if train:
        # Network building.
        network = Network()

        # Groups of neurons.
        input_layer = RealInput(n=784, sum_input=True)
        output_layer = IFNodes(n=10, sum_input=True)
        bias = RealInput(n=1, sum_input=True)
        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_layer(bias, name='Y_b')

        # Connections between groups of neurons.
        input_connection = Connection(source=input_layer,
                                      target=output_layer,
                                      norm=150,
                                      wmin=-1,
                                      wmax=1)
        bias_connection = Connection(source=bias, target=output_layer)
        network.add_connection(input_connection, source='X', target='Y')
        network.add_connection(bias_connection, source='Y_b', target='Y')

        # State variable monitoring.
        for l in network.layers:
            m = Monitor(network.layers[l], state_vars=['s'], time=time)
            network.add_monitor(m, name=l)
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784) / 255

    grads = {}
    accuracies = []
    predictions = []
    ground_truth = []
    best = -np.inf
    spike_ims, spike_axes, weights_im = None, None, None
    losses = torch.zeros(update_interval)
    correct = torch.zeros(update_interval)

    # Run training.
    start = t()
    for i in range(n_examples):
        label = torch.Tensor([labels[i % len(labels)]]).long()
        image = images[i % len(labels)]

        # Run simulation for single datum.
        inpts = {'X': image.repeat(time, 1), 'Y_b': torch.ones(time, 1)}
        network.run(inpts=inpts, time=time)

        # Retrieve spikes and summed inputs from both layers.
        spikes = {
            l: network.monitors[l].get('s')
            for l in network.layers if '_b' not in l
        }
        summed_inputs = {l: network.layers[l].summed for l in network.layers}

        # Compute softmax of output spiking activity and get predicted label.
        output = summed_inputs['Y'].softmax(0).view(1, -1)
        predicted = output.argmax(1).item()
        correct[i % update_interval] = int(predicted == label[0].item())
        predictions.append(predicted)
        ground_truth.append(label)

        # Compute cross-entropy loss between output and true label.
        losses[i % update_interval] = criterion(output, label)

        if train:
            # Compute gradient of the loss WRT average firing rates.
            grads['dl/df'] = summed_inputs['Y'].softmax(0)
            grads['dl/df'][label] -= 1

            # Compute gradient of the summed voltages WRT connection weights.
            # This is an approximation; the summed voltages are not a
            # smooth function of the connection weights.
            grads['dl/dw'] = torch.ger(summed_inputs['X'], grads['dl/df'])
            grads['dl/db'] = grads['dl/df']
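            # For reference: with p = softmax(f), dL/df_i = p_i - 1{i == label},
            # which is what the two 'dl/df' lines compute; torch.ger(x, g) is
            # the outer product x g^T, i.e. the weight gradient of a linear layer.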

            # Do stochastic gradient descent calculation.
            network.connections['X', 'Y'].w -= lr * grads['dl/dw']
            network.connections['Y_b', 'Y'].w -= lr * grads['dl/db']

        if i > 0 and i % update_interval == 0:
            accuracies.append(correct.mean() * 100)

            if train:
                if accuracies[-1] > best:
                    print()
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    best = accuracies[-1]

            print()
            print(f'Progress: {i} / {n_examples} ({t() - start:.3f} seconds)')
            print(f'Average cross-entropy loss: {losses.mean():.3f}')
            print(f'Last accuracy: {accuracies[-1]:.3f}')
            print(f'Average accuracy: {np.mean(accuracies):.3f}')

            # Decay learning rate.
            lr *= lr_decay

            if train:
                print(f'Best accuracy: {best:.3f}')
                print(f'Current learning rate: {lr:.3f}')

            start = t()

        if plot:
            w = network.connections['X', 'Y'].w
            weights = [w[:, i].view(28, 28) for i in range(10)]
            w = torch.zeros(5 * 28, 2 * 28)
            for i in range(5):
                for j in range(2):
                    w[i * 28:(i + 1) * 28,
                      j * 28:(j + 1) * 28] = weights[i + j * 5]

            spike_ims, spike_axes = plot_spikes(spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(w, im=weights_im, wmin=-1, wmax=1)

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    accuracies.append(correct.mean() * 100)

    if train:
        lr *= lr_decay
        for c in network.connections:
            network.connections[c].update_rule.weight_decay *= lr_decay

        if accuracies[-1] > best:
            print()
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            best = accuracies[-1]

    print()
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.3f} seconds)')
    print(f'Average cross-entropy loss: {losses.mean():.3f}')
    print(f'Last accuracy: {accuracies[-1]:.3f}')
    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    if train:
        print(f'Best accuracy: {best:.3f}')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((accuracies, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    results = [np.mean(accuracies), np.max(accuracies)]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'seed,n_train,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )
            else:
                f.write(
                    'seed,n_train,n_test,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(ground_truth, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
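
# Usage sketch: train with the defaults, then evaluate the saved model
# (params_path, curves_path, results_path and confusion_path are module-level
# paths defined outside this snippet):
#
#     main(train=True)
#     main(train=False)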
Example #8

# kernel_size and stride must be scalars here; tuple values would break the
# conv_size arithmetic below.
def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=16, stride=2, n_filters=25, n_output=100,
         time=100, crop=0, lr=1e-2, lr_decay=0.99, dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=1, norm=0.2,
         progress_interval=10, update_interval=250, train=True, plot=False, gpu=False):

    assert n_train % update_interval == 0, 'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, inhib, time, dt,
        theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, n_test, inhib, time, dt,
            theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        conv_size = (
            int((side_length - kernel_size) / stride) + 1,
            int((side_length - kernel_size) / stride) + 1
        )

        input_layer = Input(n=n_inpt, traces=True, trace_tc=5e-2)

        output_layer = DiehlAndCookNodes(
            n=n_filters * conv_size[0] * conv_size[1], traces=True, rest=0, reset=0,
            thresh=1, refrac=0, decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus,
            theta_decay=theta_decay
        )
        input_output_conn = LocallyConnectedConnection(
            input_layer, output_layer, kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            nu=[0, lr], update_rule=WeightDependentPostPre, wmin=0, wmax=1,
            norm=norm, input_shape=(side_length, side_length)
        )

        w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    for i in range(conv_size[0]):
                        for j in range(conv_size[1]):
                            w[fltr1, i, j, fltr2, i, j] = -inhib

        w = w.view(n_filters * conv_size[0] * conv_size[1], n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(output_layer, output_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_connection(input_output_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        output_layer = LIFNodes(
            n=n_output, traces=True, rest=0, reset=0, thresh=1, refrac=0, decay=1e-2, trace_tc=5e-2
        )

        hidden_output_connection = Connection(
            network.layers['Y'], output_layer, nu=[0, 5 * lr],
            update_rule=WeightDependentPostPre, wmin=0,
            wmax=1, norm=norm * n_output
        )

        w = -inhib * (torch.ones(n_output, n_output) - torch.diag(torch.ones(n_output)))
        output_recurrent_connection = Connection(
            output_layer, output_layer, w=w, update_rule=NoOp, wmin=-inhib, wmax=0
        )

        network.add_layer(output_layer, name='Z')
        network.add_connection(hidden_output_connection, source='Y', target='Z')
        network.add_connection(output_recurrent_connection, source='Z', target='Z')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )

        network.layers['Y'].theta = 0
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        # del network.connections['Y', 'Y']

        network.connections['Y', 'Z'].update_rule = NoOp(
            connection=network.connections['Y', 'Z'], nu=0
        )

        # network.layers['Z'].theta = 0
        # network.layers['Z'].theta_decay = 0
        # network.layers['Z'].theta_plus = 0

        # del network.connections['Z', 'Z']

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Guard the crop: images[:, 0:-0, 0:-0] would be empty when crop == 0.
    if crop > 0:
        images = images[:, crop:-crop, crop:-crop]
    images = images.contiguous().view(-1, side_length ** 2)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    unclamps = {}
    per_class = int(n_output / n_classes)
    for label in range(n_classes):
        unclamp = torch.ones(n_output).byte()
        unclamp[label * per_class: (label + 1) * per_class] = 0
        unclamps[label] = unclamp
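
    # unclamps[label] silences every output neuron outside the per_class block
    # assigned to `label`, so during training only the correct class's neurons
    # are allowed to spike (supervised clamping).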

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.save(os.path.join(params_path, model_name + '.pt'))
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.7)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Z': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

        if not train:
            retries = 0
            while spikes['Z'].get('s').sum() < 5 and retries < 3:
                retries += 1
                sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.7 + 0.1 * retries)
                inpts = {'X': sample}

                if train:
                    network.run(inpts=inpts, time=time, unclamp={'Z': unclamps[label]})
                else:
                    network.run(inpts=inpts, time=time)

        output = spikes['Z'].get('s')
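        # Output neurons are grouped contiguously by class; with per_class ==
        # n_classes (100 = 10 x 10) this view puts each class's neurons in one
        # row, so summing over dim 1 yields per-class spike totals.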
        summed_neurons = output.sum(dim=1).view(per_class, n_classes)
        summed_classes = summed_neurons.sum(dim=1)
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_output, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w, n_filters, kernel_size,
                conv_size, locations, side_length, im=weights_im
            )

            n_sqrt = int(np.ceil(np.sqrt(n_output)))
            side = int(np.ceil(np.sqrt(network.layers['Y'].n)))
            w = network.connections['Y', 'Z'].w
            w = get_square_weights(w, n_sqrt=n_sqrt, side=side)

            weights2_im = plot_weights(
                w, im=weights2_im, wmax=1
            )

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,inhib,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,inhib,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
Example #9
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=16,
         stride=4,
         n_filters=25,
         padding=0,
         inhib=500,
         lr=0.01,
         lr_decay=0.99,
         time=50,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         train=True,
         plot=False,
         gpu=False):

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if not train:
        update_interval = n_test

    if kernel_size == 32:
        conv_size = 1
    else:
        conv_size = int((32 - kernel_size + 2 * padding) / stride) + 1
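
    # Standard convolution output size: (W - K + 2P) / S + 1, with W = 32
    # (the CIFAR-10 side length).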

    per_class = int((n_filters * conv_size * conv_size) / 10)

    # Build network.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)

    conv_layer = DiehlAndCookNodes(n=n_filters * conv_size * conv_size,
                                   shape=(1, n_filters, conv_size, conv_size),
                                   traces=True)

    conv_conn = Conv2dConnection(input_layer,
                                 conv_layer,
                                 kernel_size=kernel_size,
                                 stride=stride,
                                 update_rule=PostPre,
                                 norm=0.4 * kernel_size**2,
                                 nu=[0, lr],
                                 wmin=0,
                                 wmax=1)

    w = -inhib * torch.ones(n_filters, conv_size, conv_size, n_filters,
                            conv_size, conv_size)
    for f in range(n_filters):
        for i in range(conv_size):
            for j in range(conv_size):
                w[f, i, j, f, i, j] = 0

    w = w.view(n_filters * conv_size**2, n_filters * conv_size**2)
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'),
                      download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images.mean(-1)

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    voltages = {}
    for layer in set(network.layers) - {'X'}:
        voltages[layer] = Monitor(network.layers[layer],
                                  state_vars=['v'],
                                  time=time)
        network.add_monitor(voltages[layer], name='%s_voltages' % layer)

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    voltage_ims = None
    voltage_axes = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for i in range(n_train):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_train, t() - start))
            start = t()

            if train and i > 0:
                network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample.
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Optionally plot various simulation information.
        if plot:
            # inpt = inpts['X'].view(time, 1024).sum(0).view(32, 32)

            weights1 = conv_conn.w
            _spikes = {
                'X': spikes['X'].get('s').view(32**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_size**2, time)
            }
            _voltages = {
                'Y': voltages['Y'].get('v').view(n_filters * conv_size**2,
                                                 time)
            }

            # inpt_axes, inpt_ims = plot_input(
            #     images[i].view(32, 32), inpt, label=labels[i], axes=inpt_axes, ims=inpt_ims
            # )
            # voltage_ims, voltage_axes = plot_voltages(_voltages, ims=voltage_ims, axes=voltage_axes)

            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(weights1, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' %
          (n_train, n_train, t() - start))
    print('Training complete.\n')
Example #10
def main(seed=0,
         time=250,
         n_snn_episodes=1,
         epsilon=0.05,
         plot=False,
         parameter1=1.0,
         parameter2=1.0,
         parameter3=1.0,
         parameter4=1.0,
         parameter5=1.0):

    np.random.seed(seed)

    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment,
                                frame_stack=True,
                                scale=False,
                                clip_rewards=False,
                                episode_life=False)

    print('Converting ANN to SNN...')
    # Do ANN to SNN conversion.

    # SNN = ann_to_snn(ANN, input_shape=(1, 4, 84, 84), data=states / 255.0, percentile=percentile, node_type=LIFNodes, decay=1e-2 / 13.0, rest=0.0)

    SNN = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    SNN.add_layer(input_layer, name='Input')

    children = []
    for c in ANN.children():
        if isinstance(c, nn.Sequential):
            for c2 in list(c.children()):
                children.append(c2)
        else:
            children.append(c)

    i = 0
    prev = input_layer
    scale_index = 0
    while i < len(children) - 1:
        current, nxt = children[i:i + 2]
        layer, connection = _ann_to_snn_helper(prev,
                                               current,
                                               scale=parameters[scale_index])

        i += 1

        if layer is None or connection is None:
            continue

        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

        if isinstance(current, nn.Linear) or isinstance(current, nn.Conv2d):
            scale_index += 1

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev,
                                           current,
                                           scale=parameters[scale_index])

    i += 1

    if layer is not None or connection is not None:
        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l],
                                    state_vars=['s', 'v'],
                                    time=time),
                            name=l)
        else:
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s'],
                                    time=time),
                            name=l)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(
            environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}

            SNN.run(inpts=inpts, time=time)

            spikes = {
                layer: SNN.monitors[layer].get('s')
                for layer in SNN.monitors
            }
            voltages = {
                layer: SNN.monitors[layer].get('v')
                for layer in SNN.monitors if not layer == 'Input'
            }
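            # '12' is the name the conversion loop above assigned to the final
            # (readout) layer of the converted network.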
            probs, best_action = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            next_state = torch.tensor(next_state).unsqueeze(0).permute(
                0, 3, 1, 2)

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer]
                     for layer in spikes},
                    ims=spike_ims,
                    axes=spike_axes)
                voltage_ims, voltage_axes = plot_voltages(
                    {
                        layer: voltages[layer].view(time, -1)
                        for layer in voltages
                    },
                    ims=voltage_ims,
                    axes=voltage_axes)
                inpt_axes, inpt_ims = plot_input(inpt,
                                                 inpt,
                                                 ims=inpt_ims,
                                                 axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(
                    f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}'
                )
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    model_name = '_'.join([
        str(x) for x in
        [seed, parameter1, parameter2, parameter3, parameter4, parameter5]
    ])
    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1',
        'parameter2', 'parameter3', 'parameter4', 'parameter5'
    ]
    data = [[
        seed, time, n_snn_episodes,
        np.mean(rewards), parameter1, parameter2, parameter3, parameter4,
        parameter5
    ]]

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            df = df.append(
                pd.DataFrame(data=data, index=[model_name], columns=columns))
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards,
               os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
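The commented-out ann_to_snn call above relies on data-based weight
normalization; a minimal sketch (not part of the original script) of how the
per-layer scales behind the percentile argument are typically derived from
recorded ANN activations:

import numpy as np

def activation_scale(activations, percentile=99.9):
    # Robust per-layer scale: the p-th percentile of the absolute ANN
    # activations (data-based normalization, Rueckauer et al. 2017).
    # Dividing a layer's weights by this value keeps the converted SNN's
    # firing rates in a representable range.
    return float(np.percentile(np.abs(activations), percentile))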
Example #11
import os

import matplotlib.pyplot as plt
from bindsnet.network import Network
from bindsnet.network.nodes import Input
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_voltages

from utils import *

time = 1000
dt = 1

IMG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                        '../images/car_side/image_0001.jpg')
img = read_img(IMG_PATH)
img = apply_gabor(img, (11, 11), 4.0, 0, 10, 0.5)

net = Network(dt=dt)

src = Input(shape=img.shape, traces=True)
net.add_layer(layer=src, name="SRC")

src_monitor = Monitor(obj=src, state_vars=("s", ))
net.add_monitor(monitor=src_monitor, name="SRC")

inputs = {"SRC": rank_order_encode(img, time, dt)}
net.run(inputs=inputs, time=time, decay=0.0)

spikes = {"SRC": src_monitor.get("s")}
plt.ioff()
plot_spikes(spikes)
plt.show()
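rank_order_encode comes from the local utils module; a plausible sketch,
assuming it wraps bindsnet's built-in rank-order encoder (higher-intensity
pixels spike earlier):

import torch
from bindsnet.encoding import rank_order

def rank_order_encode(img, time, dt=1.0):
    # One spike per nonzero pixel, timed by its intensity rank.
    return rank_order(datum=torch.from_numpy(img).float(), time=time, dt=dt)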
Example #12
def toLIF(network: Network):  # was not used for final implementation
    new_network = Network(dt=1, learning=True)
    input_layer = Input(n=network.X.n,
                        shape=network.X.shape,
                        traces=True,
                        tc_trace=network.X.tc_trace.item())
    exc_layer = LIFNodes(
        n=network.Ae.n,
        traces=True,
        rest=network.Ae.rest.item(),
        reset=network.Ae.reset.item(),
        thresh=network.Ae.thresh.item(),
        refrac=network.Ae.refrac.item(),
        tc_decay=network.Ae.tc_decay.item(),
    )
    inh_layer = LIFNodes(
        n=network.Ai.n,
        traces=False,
        rest=network.Ai.rest.item(),
        reset=network.Ai.reset.item(),
        thresh=network.Ai.thresh.item(),
        tc_decay=network.Ai.tc_decay.item(),
        refrac=network.Ai.refrac.item(),
    )

    # Connections
    w = network.X_to_Ae.w
    input_exc_conn = Connection(
        source=input_layer,
        target=exc_layer,
        w=w,
        update_rule=PostPre,
        nu=network.X_to_Ae.nu,
        reduction=network.X_to_Ae.reduction,
        wmin=network.X_to_Ae.wmin,
        wmax=network.X_to_Ae.wmax,
        norm=network.X_to_Ae.norm * 1,
    )
    w = network.Ae_to_Ai.w
    exc_inh_conn = Connection(source=exc_layer,
                              target=inh_layer,
                              w=w,
                              wmin=network.Ae_to_Ai.wmin,
                              wmax=network.Ae_to_Ai.wmax)
    w = network.Ai_to_Ae.w

    inh_exc_conn = Connection(source=inh_layer,
                              target=exc_layer,
                              w=w,
                              wmin=network.Ai_to_Ae.wmin,
                              wmax=network.Ai_to_Ae.wmax)

    # Add to network
    new_network.add_layer(input_layer, name="X")
    new_network.add_layer(exc_layer, name="Ae")
    new_network.add_layer(inh_layer, name="Ai")
    new_network.add_connection(input_exc_conn, source="X", target="Ae")
    new_network.add_connection(exc_inh_conn, source="Ae", target="Ai")
    new_network.add_connection(inh_exc_conn, source="Ai", target="Ae")

    exc_voltage_monitor = Monitor(new_network.layers["Ae"], ["v"], time=500)
    inh_voltage_monitor = Monitor(new_network.layers["Ai"], ["v"], time=500)
    new_network.add_monitor(exc_voltage_monitor, name="exc_voltage")
    new_network.add_monitor(inh_voltage_monitor, name="inh_voltage")

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(new_network.layers[layer],
                                state_vars=["s"],
                                time=500)
        new_network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    return new_network
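# Hypothetical usage of the converter above (assumes a trained Diehl & Cook
# style network with layers X / Ae / Ai and connections X_to_Ae, Ae_to_Ai,
# Ai_to_Ae):
# lif_network = toLIF(trained_network)
# lif_network.run(inputs={"X": encoded_image}, time=500)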
Example #13
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=100,
         lr=0.01,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt, theta_plus,
        theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_classes,
                                         rest=0,
                                         reset=1,
                                         thresh=1,
                                         decay=1e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay,
                                         traces=True,
                                         trace_tc=5e-2)
        network.add_layer(output_layer, name='Y')

        w = torch.rand(784, n_classes)
        input_connection = Connection(source=input_layer,
                                      target=output_layer,
                                      w=w,
                                      update_rule=MSTDPET,
                                      nu=lr,
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4,
                                      tc_e_trace=0.1)
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    environment = MNISTEnvironment(dataset=MNIST(path=data_path,
                                                 download=True),
                                   train=train,
                                   time=time)

    # Create pipeline.
    pipeline = Pipeline(network=network,
                        environment=environment,
                        encoding=repeat,
                        action_function=select_spiked,
                        output='Y',
                        reward_delay=None)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=('s', ),
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    network.add_monitor(
        Monitor(network.connections['X', 'Y'].update_rule,
                state_vars=('e_trace', ),
                time=time), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        for j in range(time):
            pipeline.step(a_plus=1, a_minus=0)

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            elig_ims, elig_axes = plot_voltages(
                {
                    'Y':
                    network.monitors['X_Y_e_trace'].get('e_trace').view(
                        -1, time)[1500:2000]
                },
                plot_type='line',
                ims=elig_ims,
                axes=elig_axes)

            plt.pause(1e-8)

        pipeline.reset_()  # Reset state variables.
        network.connections['X', 'Y'].update_rule.e_trace = torch.zeros(
            784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')
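The MSTDPET rule used above gates STDP with a reward signal through an
eligibility trace (Florian, 2007); a minimal sketch of one update step,
assuming scalar traces:

def mstdpet_update(e_trace, stdp_event, reward, nu, tc_e_trace, dt=1.0):
    # The eligibility trace low-pass filters signed STDP events; the weight
    # change is the reward-gated trace.
    e_trace = e_trace * (1.0 - dt / tc_e_trace) + stdp_event
    dw = nu * reward * dt * e_trace
    return e_trace, dw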
Example #14
    def create_network(self, norm=0.5, competitive_weight=-100.):
        self.norm = norm
        self.competitive_weight = competitive_weight
        self.time_max = 30
        dt = 1
        intensity = 127.5

        self.train_dataset = MNIST(
            PoissonEncoder(time=self.time_max, dt=dt),
            None,
            "MNIST",
            download=False,
            train=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
                )
            )

        # Hyperparameters
        n_filters = 25
        kernel_size = 12
        stride = 4
        padding = 0
        conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
        per_class = int((n_filters * conv_size * conv_size) / 10)
        tc_trace = 20.  # grid search check
        tc_decay = 20.
        thresh = -52
        refrac = 5

        wmin = 0
        wmax = 1

        # Network
        self.network = Network(learning=True)
        self.GlobalMonitor = NetworkMonitor(self.network, state_vars=('v', 's', 'w'))


        self.input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

        self.output_layer = AdaptiveLIFNodes(
            n=n_filters * conv_size * conv_size,
            shape=(n_filters, conv_size, conv_size),
            traces=True,
            thresh=thresh,
            trace_tc=tc_trace,
            tc_decay=tc_decay,
            theta_plus=0.05,
            tc_theta_decay=1e6)


        self.connection_XY = LocalConnection(
            self.input_layer,
            self.output_layer,
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            update_rule=PostPre,
            norm=norm,  # norm constant (alternatives tried: 1/kernel_size**2, 0.4*kernel_size**2)
            nu=[1e-4, 1e-2],
            wmin=wmin,
            wmax=wmax)

        # competitive connections
        w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    # change
                    for i in range(conv_size):
                        for j in range(conv_size):
                            w[fltr1, i, j, fltr2, i, j] = competitive_weight

        self.connection_YY = Connection(self.output_layer, self.output_layer, w=w)

        self.network.add_layer(self.input_layer, name='X')
        self.network.add_layer(self.output_layer, name='Y')

        self.network.add_connection(self.connection_XY, source='X', target='Y')
        self.network.add_connection(self.connection_YY, source='Y', target='Y')

        self.network.add_monitor(self.GlobalMonitor, name='Network')

        self.spikes = {}
        for layer in set(self.network.layers):
            self.spikes[layer] = Monitor(self.network.layers[layer], state_vars=["s"], time=self.time_max)
            self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)
            #print('GlobalMonitor.state_vars:', self.GlobalMonitor.state_vars)

        self.voltages = {}
        for layer in set(self.network.layers) - {"X"}:
            self.voltages[layer] = Monitor(self.network.layers[layer], state_vars=["v"], time=self.time_max)
            self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)
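The triple loop above that fills the competitive weights can be written
without the spatial loops; an equivalent vectorized sketch:

import torch

def competitive_weights(n_filters, conv_size, weight):
    # w[f1, i, j, f2, k, l] = weight iff f1 != f2 and (i, j) == (k, l).
    filter_mask = 1.0 - torch.eye(n_filters)
    spatial_mask = torch.eye(conv_size ** 2).view(conv_size, conv_size,
                                                  conv_size, conv_size)
    return weight * filter_mask.view(n_filters, 1, 1, n_filters, 1, 1) * \
        spatial_mask.view(1, conv_size, conv_size, 1, conv_size, conv_size)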
Example #15
    def test_gym_pipeline(self):
        # Build network.
        network = Network(dt=1.0)

        # Layers of neurons.
        inpt = Input(n=6552, traces=True)
        middle = LIFNodes(n=225, traces=True, thresh=-52.0 + torch.randn(225))
        out = LIFNodes(n=60, refrac=0, traces=True, thresh=-40.0)

        # Connections between layers.
        inpt_middle = Connection(source=inpt, target=middle, wmax=1e-2)
        middle_out = Connection(source=middle,
                                target=out,
                                wmax=0.5,
                                update_rule=m_stdp_et,
                                nu=2e-2,
                                norm=0.15 * middle.n)

        # Add all layers and connections to the network.
        network.add_layer(inpt, name='X')
        network.add_layer(middle, name='Y')
        network.add_layer(out, name='Z')
        network.add_connection(inpt_middle, source='X', target='Y')
        network.add_connection(middle_out, source='Y', target='Z')

        # Load SpaceInvaders environment.
        environment = GymEnvironment('SpaceInvaders-v0')
        environment.reset()

        # Build pipeline from specified components.
        for history_length in [3, 4, 5, 6]:
            for delta in [2, 3, 4]:
                p = Pipeline(network,
                             environment,
                             encoding=bernoulli,
                             action_function=select_multinomial,
                             output='Z',
                             time=1,
                             history_length=history_length,
                             delta=delta)

                assert p.action_function == select_multinomial
                assert p.history_length == history_length
                assert p.delta == delta

        # Checking assertion errors
        for time in [0, -1]:
            try:
                p = Pipeline(network,
                             environment,
                             encoding=bernoulli,
                             action_function=select_multinomial,
                             output='Z',
                             time=time,
                             history_length=2,
                             delta=4)
            except ValueError:
                pass

        for delta in [0, -1]:
            try:
                p = Pipeline(network,
                             environment,
                             encoding=bernoulli,
                             action_function=select_multinomial,
                             output='Z',
                             time=1,
                             history_length=2,
                             delta=delta)
            except ValueError:
                pass

        for output in ['K']:
            try:
                p = Pipeline(network,
                             environment,
                             encoding=bernoulli,
                             action_function=select_multinomial,
                             output=output,
                             time=1,
                             history_length=2,
                             delta=4)
            except ValueError:
                pass

        p = Pipeline(network,
                     environment,
                     encoding=bernoulli,
                     action_function=select_random,
                     output='Z',
                     time=1,
                     history_length=2,
                     delta=4,
                     save_interval=50,
                     render_interval=5)

        assert p.action_function == select_random
        assert p.encoding == bernoulli
        assert p.save_interval == 50
        assert p.render_interval == 5
        assert p.time == 1
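Example #16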
import sys

import torch

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes, CurrentLIFNodes, AdaptiveLIFNodes, IzhikevichNodes
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor

import matplotlib.pyplot as plt
from bindsnet.analysis.plotting import plot_voltages, plot_spikes

### initialisation
dt = 0.1
simulation_time = 500
if len(sys.argv) == 3:
    stimulation = float(sys.argv[2])
else:
    stimulation = 0.1

nodes_network = Network(dt=dt)
input_layer = Input(n=1, traces=True)
nodes_network.add_layer(layer=input_layer, name="Input")
input_monitor = Monitor(obj=input_layer, state_vars=("s",))
nodes_network.add_monitor(monitor=input_monitor, name="input monitor")

### input data
input_data = {
    "Input":
    stimulation * torch.bernoulli(
        0.1 * torch.ones(int(simulation_time / dt), input_layer.n)).byte()
}


### LIFNodes
def LIF(nodes_network):
    # Sketch body (assumed; the source snippet is cut off here), mirroring the
    # other node-type demos: add one LIF neuron, wire it to the input, and
    # monitor its spikes and voltage.
    lif_layer = LIFNodes(n=1, traces=True)
    nodes_network.add_layer(layer=lif_layer, name="LIF")
    connection = Connection(source=nodes_network.layers["Input"],
                            target=lif_layer,
                            w=torch.ones(1, 1))
    nodes_network.add_connection(connection=connection,
                                 source="Input",
                                 target="LIF")
    monitor = Monitor(obj=lif_layer, state_vars=("s", "v"))
    nodes_network.add_monitor(monitor=monitor, name="LIF monitor")
    return nodes_network
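Example #17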
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         lr=1e-2,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=-65.0,
                                         reset=-60.0,
                                         thresh=-52.0,
                                         refrac=5,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'],
            target=network.layers['Y'],
            w=w,
            update_rule=CompetitivePost,
            nu=[torch.zeros(784), lr * torch.ones(n_neurons)],
            wmin=0,
            wmax=1,
            norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers) - {'X'}:
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.lr[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, 10, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, 10, 2,
                                                   ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        if train:
            input_connection.update_rule.nu[
                1] = input_connection.update_rule.lr[1].clone()
            input_connection.update_rule.first = False

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            # square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            if train:
                network.save(os.path.join(params_path, model_name + '.pt'))
                path = os.path.join(
                    params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                torch.save((assignments, proportions, rates, ngram_scores),
                           open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
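The poisson encoder above turns pixel intensities into spike trains. bindsnet
draws Poisson inter-spike intervals; a simplified Bernoulli rate-coding
approximation, as a sketch:

import torch

def poisson_approx(datum, time, dt=1.0):
    # Treat each intensity as a rate in Hz; the per-step spike probability is
    # rate * dt / 1000, clamped to a valid probability.
    prob = (datum.flatten().float() * dt / 1000.0).clamp(0, 1)
    return torch.bernoulli(prob.repeat(time, 1)).byte()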
Example #18
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=(8, 8),
         stride=(4, 4),
         n_filters=25,
         n_full=100,
         padding=0,
         inhib=100,
         time=100,
         lr=1e-3,
         lr_decay=0.99,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, n_full, padding, inhib,
        time, lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, n_full,
            padding, inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [28, 28]

    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))
    n_neurons = n_filters * total_conv_size
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=784, shape=(1, 1, 28, 28), traces=True)
        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0,
                                       traces=True,
                                       theta_plus=0.05,
                                       refrac=0)
        conv_layer_prime = LIFNodes(n=n_filters * total_conv_size,
                                    shape=(1, n_filters, *conv_size),
                                    refrac=0,
                                    traces=True)
        conv_conn = Conv2dConnection(input_layer,
                                     conv_layer,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     update_rule=PostPre,
                                     norm=0.5 *
                                     int(np.sqrt(total_kernel_size)),
                                     nu=[0, lr],
                                     wmax=2.0)
        conv_conn_prime = Conv2dConnection(input_layer,
                                           conv_layer_prime,
                                           w=conv_conn.w,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           nu=[0, 0],
                                           wmax=2.0)

        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for i in range(conv_size[0]):
                for j in range(conv_size[1]):
                    w[f, i, j, f, i, j] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        full_layer = DiehlAndCookNodes(n=n_full,
                                       thresh=-52.0,
                                       traces=True,
                                       theta_plus=0.05,
                                       refrac=0)
        full_layer_prime = LIFNodes(n=n_full, refrac=0)
        full_conn = Connection(conv_layer_prime,
                               full_layer,
                               update_rule=PostPre,
                               norm=0.2 * n_neurons,
                               nu=[0, 10 * lr],
                               wmax=1)
        full_conn_prime = Connection(conv_layer_prime,
                                     full_layer_prime,
                                     nu=0,
                                     wmax=1)

        w = -inhib * (torch.ones(n_full, n_full) -
                      torch.diag(torch.ones(n_full)))
        recurrent_conn2 = Connection(full_layer, full_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer_prime, name='Y_')
        network.add_layer(full_layer, name='Z')
        network.add_layer(full_layer_prime, name='Z_')

        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn_prime, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')
        network.add_connection(full_conn, source='Y_', target='Z')
        network.add_connection(full_conn_prime, source='Y_', target='Z_')
        network.add_connection(recurrent_conn2, source='Z', target='Z')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        for connection in network.connections.values():
            connection.update_rule = NoOp(connection, connection.nu)
            connection.theta_decay = 0
            connection.theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_full)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_full))
        proportions = torch.zeros_like(torch.Tensor(n_full, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_full, n_classes))
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = np.zeros([n_classes, n_full])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, logreg_coef, logreg_intercept = torch.load(
            open(path, 'rb'))
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights_im2 = None
    assigns_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates,
                                logreg_model.coef_, logreg_model.intercept_),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Refit logistic regression model.
                logreg_model = logreg_fit(spike_record, current_labels,
                                          logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=0.5).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Z'].get('s').sum() < 5 and retries < 3:
            retries += 1
            sample = bernoulli(datum=image,
                               time=time,
                               dt=dt,
                               max_prob=0.5 +
                               retries * 0.15).unsqueeze(1).unsqueeze(1)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Z'].get('s').view(time, -1)

        # Optionally plot various simulation information.
        if plot:
            _input = inpts['X'].view(time, 784).sum(0).view(28, 28)
            w = network.connections['X', 'Y'].w
            w2 = network.connections['Y_', 'Z'].w
            _spikes = {
                'X': spikes['X'].get('s').view(28**2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Y_': spikes['Y_'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_full, time),
                'Z_': spikes['Z_'].get('s').view(n_full, time)
            }
            square_assignments = get_square_assignments(assignments, n_sqrt)

            inpt_axes, inpt_ims = plot_input(image.view(28, 28),
                                             _input,
                                             label=labels[i],
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(w, im=weights_im, wmax=0.2)
            weights_im2 = plot_weights(w2, im=weights_im2, wmax=1)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, logreg_model.coef_,
                        logreg_model.intercept_), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['logreg']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['logreg'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'n_full', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_all_activity',
                    'mean_proportion_weighting', 'mean_logreg',
                    'max_all_activity', 'max_proportion_weighting',
                    'max_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'n_full', 'padding', 'inhib', 'time', 'lr',
                    'lr_decay', 'dt', 'intensity', 'update_interval',
                    'mean_all_activity', 'mean_proportion_weighting',
                    'mean_logreg', 'max_all_activity',
                    'max_proportion_weighting', 'max_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
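logreg_fit above is a local helper; a plausible sketch, assuming it refits the
warm-started scikit-learn model on per-example output spike counts:

def logreg_fit(spike_record, labels, logreg):
    # (update_interval, time, n_full) -> (update_interval, n_full) counts.
    x = spike_record.sum(1).cpu().numpy()
    logreg.fit(x, labels.cpu().numpy())
    return logreg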
Example #19
    def __init__(
        self,
        dt=1.0,
        A=1.0,
        PN_KC_weight=0.25,
        KC_EN_weight=2.0,
        min_weight=0.0001,
        PN_thresh=-40.0,
        KC_thresh=-25.0,
        EN_thresh=-40.0,
    ):

        self.landmark_guidance = Network(dt=dt)

        # layers
        self.input_layer = Input(n=360, shape=(10, 36))
        self.PN = Izhikevich(n=360,
                             traces=True,
                             tc_decay=10.0,
                             thresh=PN_thresh,
                             rest=-60.0,
                             C=100,
                             a=0.3,
                             b=-0.2,
                             c=-65,
                             d=8,
                             k=2)
        self.KC = Izhikevich(n=20000,
                             traces=True,
                             tc_decay=10.0,
                             thresh=KC_thresh,
                             rest=-85.0,
                             C=4,
                             a=0.01,
                             b=-0.3,
                             c=-65,
                             d=8,
                             k=0.035)
        self.EN = Izhikevich(n=1,
                             traces=True,
                             tc_decay=10.0,
                             thresh=EN_thresh,
                             rest=-60.0,
                             C=100,
                             a=0.3,
                             b=-0.2,
                             c=-65,
                             d=8,
                             k=2)
        self.landmark_guidance.add_layer(layer=self.input_layer, name="Input")
        self.landmark_guidance.add_layer(layer=self.PN, name="PN")
        self.landmark_guidance.add_layer(layer=self.KC, name="KC")
        self.landmark_guidance.add_layer(layer=self.EN, name="EN")

        # connections
        connection_weight = torch.zeros(self.input_layer.n,
                                        self.PN.n).fill_diagonal_(1)
        self.input_PN = Connection(source=self.input_layer,
                                   target=self.PN,
                                   w=connection_weight)

        connection_weight = torch.zeros(self.PN.n, self.KC.n).t()
        connection_weight = connection_weight.scatter_(
            1,
            torch.tensor([
                np.random.choice(self.PN.n, size=10, replace=False)
                for i in range(self.KC.n)
            ]).long(), PN_KC_weight)
        self.PN_KC = AllToAllConnection(source=self.PN,
                                        target=self.KC,
                                        w=connection_weight.t(),
                                        tc_synaptic=3.0,
                                        phi=0.93)

        self.KC_EN = AllToAllConnection(source=self.KC,
                                        target=self.EN,
                                        w=torch.ones(self.KC.n, self.EN.n) *
                                        KC_EN_weight,
                                        tc_synaptic=8.0,
                                        phi=8.0)

        self.landmark_guidance.add_connection(connection=self.input_PN,
                                              source="Input",
                                              target="PN")
        self.landmark_guidance.add_connection(connection=self.PN_KC,
                                              source="PN",
                                              target="KC")
        self.landmark_guidance.add_connection(connection=self.KC_EN,
                                              source="KC",
                                              target="EN")

        # learning rule
        self.KC_EN.update_rule = STDP(connection=self.KC_EN,
                                      nu=(-A, -A),
                                      tc_eligibility_trace=40.0,
                                      tc_plus=15,
                                      tc_minus=15,
                                      tc_reward=20.0,
                                      min_weight=min_weight)

        # monitors
        input_monitor = Monitor(obj=self.input_layer, state_vars=("s",))
        PN_monitor = Monitor(obj=self.PN, state_vars=("s", "v"))
        KC_monitor = Monitor(obj=self.KC, state_vars=("s", "v"))
        EN_monitor = Monitor(obj=self.EN, state_vars=("s", "v"))
        self.landmark_guidance.add_monitor(monitor=input_monitor,
                                           name="Input monitor")
        self.landmark_guidance.add_monitor(monitor=PN_monitor,
                                           name="PN monitor")
        self.landmark_guidance.add_monitor(monitor=KC_monitor,
                                           name="KC monitor")
        self.landmark_guidance.add_monitor(monitor=EN_monitor,
                                           name="EN monitor")

        # plots
        self.plots = {}

        # number of EN spikes during the simulation
        self.nb_spikes_EN = 0
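# The scatter_ call above wires each of the 20,000 KCs to exactly 10 randomly
# chosen PNs. Hypothetical usage (the enclosing class name is not shown in
# this fragment):
# model = MushroomBodyModel(dt=1.0)
# model.landmark_guidance.run(inputs={"Input": view_sequence}, time=50)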
Example #20
    def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args,
                     **kwargs):
        print("Testing:", conn_type)
        time = 100
        weights = [None, torch.Tensor(*shape_w)]
        wmins = [
            -np.inf,
            0,
            torch.zeros(*shape_w),
            torch.zeros(*shape_w).masked_fill(
                torch.bernoulli(torch.rand(*shape_w)) == 1, -np.inf),
        ]
        wmaxes = [
            np.inf,
            0,
            torch.ones(*shape_w),
            torch.randn(*shape_w).masked_fill(
                torch.bernoulli(torch.rand(*shape_w)) == 1, np.inf),
        ]
        update_rule = kwargs.get("update_rule", None)
        for w in weights:
            for wmin in wmins:
                for wmax in wmaxes:

                    ### Conditional checks ###
                    # WeightDependentPostPre does not handle infinite ranges
                    if ((torch.tensor(wmin, dtype=torch.float32)
                         == -np.inf).any() or
                        (torch.tensor(wmax, dtype=torch.float32) == np.inf
                         ).any()) and update_rule == WeightDependentPostPre:
                        continue

                    # Rmax only supported for Connection & LocalConnection
                    elif (not (conn_type == Connection
                               or conn_type == LocalConnection)
                          and update_rule == Rmax):
                        return

                    print(
                        f"- w: {type(w).__name__}, "
                        f"wmin: {type(wmin).__name__}, wmax: {type(wmax).__name__}"
                    )
                    if kwargs.get("update_rule") == Rmax:
                        l_a = SRM0Nodes(shape=shape_a,
                                        traces=True,
                                        traces_additive=True)
                        l_b = SRM0Nodes(shape=shape_b,
                                        traces=True,
                                        traces_additive=True)
                    else:
                        l_a = LIFNodes(shape=shape_a,
                                       traces=True,
                                       traces_additive=True)
                        l_b = LIFNodes(shape=shape_b,
                                       traces=True,
                                       traces_additive=True)

                    ### Create network ###
                    network = Network(dt=1.0)
                    network.add_layer(Input(n=100,
                                            traces=True,
                                            traces_additive=True),
                                      name="input")
                    network.add_layer(l_a, name="a")
                    network.add_layer(l_b, name="b")

                    network.add_connection(
                        conn_type(l_a,
                                  l_b,
                                  w=w,
                                  wmin=wmin,
                                  wmax=wmax,
                                  *args,
                                  **kwargs),
                        source="a",
                        target="b",
                    )
                    network.add_connection(
                        Connection(
                            wmin=0,
                            wmax=1,
                            source=network.layers["input"],
                            target=network.layers["a"],
                            **kwargs,
                        ),
                        source="input",
                        target="a",
                    )

                    ### Run network ###
                    network.run(
                        inputs={
                            "input": torch.bernoulli(torch.rand(time,
                                                                100)).byte()
                        },
                        time=time,
                        reward=1,
                    )
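    # Hypothetical usage sketch (not from the original test suite): assuming
    # this method lives on a unittest/pytest-style class instance `tester`,
    # a dense-to-dense sweep could be driven like so; shape_w must equal
    # (prod(shape_a), prod(shape_b)) for a plain Connection, and an
    # update_rule such as WeightDependentPostPre can be passed via kwargs.
    #
    #     tester.test_weights(Connection,
    #                         shape_a=(100,), shape_b=(50,), shape_w=(100, 50))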
Example #21
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=250,
         lr=1e-2,
         lr_decay=1,
         time=100,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         intensity=1,
         progress_interval=10,
         update_interval=100,
         plot=False,
         train=True,
         gpu=False,
         no_inhib=False,
         no_theta=False):

    assert n_train % update_interval == 0, 'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10
    per_class = int(n_neurons / n_classes)

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=0,
                                         reset=0,
                                         thresh=5,
                                         refrac=0,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'],
                                      target=network.layers['Y'],
                                      w=w,
                                      update_rule=WeightDependentPostPre,
                                      nu=[0, lr],
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        if no_inhib:
            del network.connections['Y', 'Y']

        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            monitors[layer] = Monitor(network.layers[layer],
                                      state_vars=['s', 'v'],
                                      time=time)
        else:
            monitors[layer] = Monitor(network.layers[layer],
                                      state_vars=['s'],
                                      time=time)

        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    voltage_ims = None
    voltage_axes = None
    weights_im = None
    theta_im = None

    unclamps = {}
    for label in range(n_classes):
        unclamp = torch.ones(n_neurons).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)
    spike_record = torch.zeros(n_examples, n_neurons)

    flag = False

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0 and train:
            network.save(os.path.join(params_path, model_name + '.pt'))
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if not flag:
                w = -inhib * (torch.ones(n_neurons, n_neurons) -
                              torch.diag(torch.ones(n_neurons)))
                recurrent_connection = Connection(source=network.layers['Y'],
                                                  target=network.layers['Y'],
                                                  w=w,
                                                  wmin=-inhib,
                                                  wmax=0)
                network.add_connection(recurrent_connection,
                                       source='Y',
                                       target='Y')

            flag = True

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Y': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

        output = monitors['Y'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1).long()
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        spike_record[i] = output.float().sum(dim=1)

        # Optionally plot various simulation information.
        if plot and i % update_interval == 0:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            # v = {'Y': monitors['Y'].get('v')}

            s = {layer: monitors[layer].get('s') for layer in monitors}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            theta = network.layers['Y'].theta.view(per_class, per_class)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            # voltage_ims, voltage_axes = plot_voltages(v, ims=voltage_ims, axes=voltage_axes)

            spike_ims, spike_axes = plot_spikes(s,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # if theta_im is None:
            #     theta_im = plt.matshow(theta)
            #     cax = plt.colorbar()
            # else:
            #     theta_im.set_data(theta)
            #     cax.set_clim(theta.min(), theta.max())

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,'
                    'theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    if plot:
        plt.ioff()
        plt.matshow(confusion)
        plt.show()

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
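# Hypothetical entry point (not part of the original snippet): every
# hyperparameter of main() is a keyword argument, so a quick smoke test might
# look like the following; n_train must stay divisible by update_interval, as
# the assert at the top of main() enforces.
#
#     if __name__ == '__main__':
#         main(n_train=100, n_test=100, update_interval=100)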
Example #22
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_voltages
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.learning import PostPre
from bindsnet.datasets import MNIST
from tqdm import tqdm


time = 500


network = Network(dt=1, learning=True)

layerIn = Input(n=28*28, traces=True)
layer1 = LIFNodes(n=100, traces=True)
layer2 = LIFNodes(n=100, traces=True)
layerOut = LIFNodes(n=10, traces=True)

con1 = Connection(source=layerIn, target=layer1, update_rule=PostPre, nu=(1e-4, 1e-2))
con2 = Connection(source=layer1, target=layer2, update_rule=PostPre, nu=(1e-4, 1e-2))
con3 = Connection(source=layer2, target=layerOut, update_rule=PostPre, nu=(1e-4, 1e-2))

outMonitor = Monitor(
    obj=layerOut,
    state_vars=("s", "v"),  # Record spikes and voltages.
    time=time,  # Length of simulation (if known ahead of time).
)
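# The snippet ends before anything is wired together; a plausible continuation
# using only the objects defined above (the layer names here are illustrative,
# and network.run takes `inputs=` in recent bindsnet releases, `inpts=` in
# older ones):
import torch

network.add_layer(layerIn, name="In")
network.add_layer(layer1, name="L1")
network.add_layer(layer2, name="L2")
network.add_layer(layerOut, name="Out")
network.add_connection(con1, source="In", target="L1")
network.add_connection(con2, source="L1", target="L2")
network.add_connection(con3, source="L2", target="Out")
network.add_monitor(outMonitor, name="Out monitor")

# Drive the network with one Poisson-encoded random "image" as a smoke test.
encoder = PoissonEncoder(time=time, dt=1)
sample = encoder(torch.rand(28 * 28) * 128)  # stand-in for a real MNIST digit
network.run(inputs={"In": sample}, time=time)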
Example #23
plot = args.plot
gpu = True

input_size = rf_size
tnn_layer_sz = 50
num_timesteps = 8
# tnn_thresh = 80
max_weight = num_timesteps
# num_winners = 40 #tnn_layer_sz

time = num_timesteps

torch.manual_seed(seed)

# build network:
network = Network(dt=1)
input_layer = Input(n=input_size)
tnn_layer_1 = TemporalNeurons(
    n=tnn_layer_sz,
    timesteps=num_timesteps,
    threshold=30,
    num_winners=4,
)

tnn_layer_2 = TemporalNeurons(
    n=tnn_layer_sz,
    timesteps=num_timesteps,
    threshold=30,
    num_winners=1,
)
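# The fragment stops here (rf_size, args, and seed come from earlier, unshown
# code, and TemporalNeurons is a project-specific node type). Presumably the
# layers would next be registered on the network in the usual bindsnet way:

network.add_layer(input_layer, name="I")
network.add_layer(tnn_layer_1, name="TNN_1")
network.add_layer(tnn_layer_2, name="TNN_2")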
Example #24
def main():

    # Build network.
    network = Network(dt=dt)

    # Input Layer
    inpt = Input(n=dim * dim, shape=[1, 1, 1, dim, dim], traces=True)

    # Hidden Layer
    middle = LIFNodes(n=neurons, traces=True)

    # Output Layer
    out = LIFNodes(n=moveChoices, refrac=0, traces=True)

    # Connections from input layer to hidden layer
    inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1)

    # Connections from hidden layer to output layer
    middle_out = Connection(
        source=middle,
        target=out,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=MSTDPET,  # learning rule
        nu=1e-1,  # learning rate
        norm=0.5 * middle.n,  # normalization
    )

    # Recurrent connection, retaining data within the hidden layer
    recurrent = Connection(
        source=middle,
        target=middle,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=PostPre,  # learning rule
        nu=1e-1,  # learning rate
        norm=5e-3 * middle.n,  # normalization
    )

    # Add all layers and connections to the network.
    network.add_layer(inpt, name=LAYER1)
    network.add_layer(middle, name=LAYER2)
    network.add_layer(out, name=LAYER3)
    network.add_connection(inpt_middle, source=LAYER1, target=LAYER2)
    network.add_connection(middle_out, source=LAYER2, target=LAYER3)
    network.add_connection(recurrent, source=LAYER2, target=LAYER2)
    network.to(DEVICE)

    # Add monitors
    # network.add_monitor(Monitor(network.layers["Hidden"], ["s"], time=granularity), "Hidden")
    # network.add_monitor(Monitor(network.layers["Output"], ["s"], time=granularity), "Output")
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(
            network.layers[layer],
            state_vars=["s"],
            time=int(granularity / dt),
            device=DEVICE,
        )
        network.add_monitor(spikes[layer], name=layer)

    # Load the Dot Simulation environment.
    environment = DotSimulator(
        steps,
        decay=decay,
        herrs=herrs,
        diag=diag,
        randr=randr,
        write=write,
        mute=mute,
        bound_hand=boundh,
        fit_func=fit_func,
        allow_stay=allow_stay,
        pandas=pandas,
        fpath=OUT_FILE_PATH,
    )
    environment.reset()

    print("Training: ")
    rewFile = genFileName("rew", "train")
    perfFile = genFileName("perf", "train")
    environment.addFileSuffix("train")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=trn_eps,
        gran=granularity,
        rfname=rewFile,
        pfname=perfFile,
    )

    # Freeze learning
    network.learning = False

    print("Testing: ")
    rewFile = genFileName("rew", "test")
    perfFile = genFileName("perf", "test")
    environment.changeFileSuffix("train", "test")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=tst_eps,
        gran=granularity,
        rfname=rewFile,
        pfname=perfFile,
    )
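    # Note: this main() leans on module-level configuration not shown in the
    # snippet -- dt, dim, neurons, moveChoices, granularity, steps, decay,
    # herrs, diag, randr, write, mute, boundh, fit_func, allow_stay, pandas,
    # trn_eps, tst_eps, the LAYER1/LAYER2/LAYER3, DEVICE and OUT_FILE_PATH
    # constants, and the helpers genFileName(), runSimulator() and
    # DotSimulator are all assumed to be defined elsewhere.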
Example #25
import torch

from bindsnet.network import Network
from bindsnet.pipeline import Pipeline
from bindsnet.encoding import bernoulli
from bindsnet.network.topology import Connection
from bindsnet.environment import GymEnvironment
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.pipeline.action import select_softmax

# Build network.
network = Network(dt=1.0)

# Layers of neurons.
inpt = Input(n=80 * 80, shape=[80, 80], traces=True)
middle = LIFNodes(n=100, traces=True)
out = LIFNodes(n=4, refrac=0, traces=True)

# Connections between layers.
inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1e-1)
middle_out = Connection(source=middle, target=out, wmin=0, wmax=1)

# Add all layers and connections to the network.
network.add_layer(inpt, name="Input Layer")
network.add_layer(middle, name="Hidden Layer")
network.add_layer(out, name="Output Layer")
network.add_connection(inpt_middle,
                       source="Input Layer",
                       target="Hidden Layer")
network.add_connection(middle_out,
                       source="Hidden Layer",
                       target="Output Layer")
Example #26
def main(seed=0, n_train=60000, n_test=10000, c_low=1, c_high=25, p_low=0.5, kernel_size=(16,), stride=(2,),
         n_filters=25, crop=4, lr=0.01, lr_decay=1, time=100, dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=1,
         norm=0.2, progress_interval=10, update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, c_low, c_high, p_low, time, dt,
        theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, n_test, c_low, c_high, p_low, time, dt,
            theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    input_shape = [side_length, side_length]
    n_examples = n_train if train else n_test
    n_classes = 10

    if list(_pair(kernel_size)) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - _pair(kernel_size)[0]) / _pair(stride)[0]) + 1,
                     int((input_shape[1] - _pair(kernel_size)[1]) / _pair(stride)[1]) + 1)

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=n_inpt, traces=True, trace_tc=5e-2)
        output_layer = DiehlAndCookNodes(
            n=n_filters * conv_size[0] * conv_size[1], traces=True, rest=-65.0, reset=-60.0,
            thresh=-52.0, refrac=5, decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        input_output_conn = LocallyConnectedConnection(
            input_layer, output_layer, kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            nu=[0, lr], update_rule=PostPre, wmin=0, wmax=1, norm=norm, input_shape=input_shape
        )

        w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    for j in range(conv_size[0]):
                        for k in range(conv_size[1]):
                            x1, y1 = fltr1 // np.sqrt(n_filters), fltr1 % np.sqrt(n_filters)
                            x2, y2 = fltr2 // np.sqrt(n_filters), fltr2 % np.sqrt(n_filters)

                            w[fltr1, j, k, fltr2, j, k] = max(-c_high, -c_low * np.sqrt(euclidean([x1, y1], [x2, y2])))

        w = w.view(n_filters * conv_size[0] * conv_size[1], n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(output_layer, output_layer, w=w)

        plt.matshow(w)
        plt.colorbar()

        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_connection(input_output_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images[:, crop:-crop, crop:-crop]

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {
        scheme: torch.Tensor().long() for scheme in curves.keys()
    }

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    # Calculate linear increase every update interval.
    if train:
        n_increase = int(p_low * n_examples) / update_interval
        increase = (c_high - c_low) / n_increase
        increases = 0
        inhib = c_low

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

                if increases < n_increase:
                    inhib = inhib + increase

                    print(f'\nIncreasing inhibition to {inhib}.\n')

                    w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
                    for fltr1 in range(n_filters):
                        for fltr2 in range(n_filters):
                            if fltr1 != fltr2:
                                for j in range(conv_size[0]):
                                    for k in range(conv_size[1]):
                                        x1, y1 = fltr1 // np.sqrt(n_filters), fltr1 % np.sqrt(n_filters)
                                        x2, y2 = fltr2 // np.sqrt(n_filters), fltr2 % np.sqrt(n_filters)

                                        w[fltr1, j, k, fltr2, j, k] = max(-c_high, -c_low * np.sqrt(euclidean([x1, y1], [x2, y2])))

                    w = w.view(n_filters * conv_size[0] * conv_size[1], n_filters * conv_size[0] * conv_size[1])
                    network.connections['Y', 'Y'].w = w

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples), open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores), open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(spike_record, current_labels, 10, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, 10, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)].contiguous().view(-1)
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections[('X', 'Y')].w, n_filters, kernel_size, conv_size, locations, side_length, im=weights_im
            )

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
        proportions=proportions, ngram_scores=ngram_scores, n=2
    )
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples), open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,c_low,c_high,p_low,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,c_low,c_high,p_low,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
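    # Assumed context: params_path, curves_path, results_path, confusion_path,
    # data_path and the helpers update_curves(), print_results(),
    # update_ngram_scores(), plot_locally_connected_weights() and
    # confusion_matrix() are defined or imported elsewhere in the original
    # module; none of them appear in this fragment.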
Example #27
def main(n_hidden=100, time=100, lr=5e-2, plot=False, gpu=False):
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    network = Network()

    input_layer = Input(n=784, traces=True)
    hidden_layer = DiehlAndCookNodes(n=n_hidden, rest=0, reset=0, thresh=1, traces=True)
    output_layer = LIFNodes(n=784, rest=0, reset=0, thresh=1, traces=True)
    input_hidden_connection = Connection(
        input_layer, hidden_layer, wmin=0, wmax=1, norm=75, update_rule=Hebbian, nu=[0, lr]
    )
    hidden_hidden_connection = Connection(
        hidden_layer, hidden_layer, wmin=-500, wmax=0,
        # strong lateral inhibition on the off-diagonal, none on the diagonal
        w=-500 * (torch.ones(n_hidden, n_hidden) - torch.diag(torch.ones(n_hidden)))
    )
    hidden_output_connection = Connection(
        hidden_layer, input_layer, wmin=0, wmax=1, norm=15, update_rule=Hebbian, nu=[lr, 0]
    )

    network.add_layer(input_layer, name='X')
    network.add_layer(hidden_layer, name='H')
    network.add_layer(output_layer, name='Y')
    network.add_connection(input_hidden_connection, source='X', target='H')
    network.add_connection(hidden_hidden_connection, source='H', target='H')
    network.add_connection(hidden_output_connection, source='H', target='Y')

    for layer in network.layers:
        monitor = Monitor(
            obj=network.layers[layer], state_vars=('s',), time=time
        )
        network.add_monitor(monitor, name=layer)

    dataset = MNIST(
        path=os.path.join(ROOT_DIR, 'data', 'MNIST'), shuffle=True, download=True
    )

    images, labels = dataset.get_train()
    images = images.view(-1, 784)
    images /= 4
    labels = labels.long()

    spikes_ims = None
    spikes_axes = None
    weights1_im = None
    weights2_im = None
    inpt_ims = None
    inpt_axes = None

    for image, label in zip(images, labels):
        spikes = poisson(image, time=time, dt=network.dt)
        inpts = {'X': spikes}
        clamp = {'Y': spikes}
        unclamp = {'Y': ~spikes}

        network.run(
            inpts=inpts, time=time, clamp=clamp, unclamp=unclamp
        )

        if plot:
            spikes = {
                l: network.monitors[l].get('s') for l in network.layers
            }
            spikes_ims, spikes_axes = plot_spikes(
                spikes, ims=spikes_ims, axes=spikes_axes
            )

            inpt = spikes['X'].float().mean(1).view(28, 28)
            rcstn = spikes['Y'].float().mean(1).view(28, 28)

            inpt_axes, inpt_ims = plot_input(
                inpt, rcstn, label=label, axes=inpt_axes, ims=inpt_ims
            )

            w1 = get_square_weights(
                network.connections['X', 'H'].w.view(784, n_hidden), int(np.ceil(np.sqrt(n_hidden))), 28
            )
            w2 = get_square_weights(
                network.connections['H', 'Y'].w.view(n_hidden, 784).t(), int(np.ceil(np.sqrt(n_hidden))), 28
            )

            weights1_im = plot_weights(
                w1, wmin=0, wmax=1, im=weights1_im
            )
            weights2_im = plot_weights(
                w2, wmin=0, wmax=1, im=weights2_im
            )

            plt.pause(0.01)
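    # How the clamp/unclamp pair above works: `clamp` forces the listed 'Y'
    # neurons to spike whenever the Poisson-encoded input spikes, and
    # `unclamp` (~spikes) suppresses every other 'Y' spike. With Hebbian
    # updates on X->H (nu=[0, lr]) and H->Y (nu=[lr, 0]), the network is in
    # effect trained to reconstruct its input through the hidden layer, which
    # is why the plot compares `inpt` against the reconstruction `rcstn`.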
Example #28
def prepare_network():
    global net

    net = Network()

    for g_size in G_SIZES:
        s1_g_size = Input(shape=(len(G_THETAS), IMG_SHAPE[0], IMG_SHAPE[1],), traces=True)
        net.add_layer(layer=s1_g_size, name=s1_name(g_size))

        c1_g_size = LIFNodes(shape=(len(G_THETAS), IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2,), thresh=-64, traces=True)
        net.add_layer(layer=c1_g_size, name=c1_name(g_size))

        max_pool_con = MaxPool2dConnection(s1_g_size, c1_g_size, kernel_size=2, stride=2, decay=0.0)
        net.add_connection(max_pool_con, s1_name(g_size), c1_name(g_size))

    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            s2_nodes = LIFNodes(shape=(1, IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2,), traces=True, tc_decay=50.0, thresh=-55, trace_scale=0.2)
            net.add_layer(layer=s2_nodes, name=s2_name(f_idx, g_size))

            conv_con = Conv2dConnection(net.layers[c1_name(g_size)], s2_nodes, 5, padding=2,  nu=[0.0006, 0.008], update_rule=PostPre, wmin=0, wmax=1)
            net.add_connection(conv_con, c1_name(g_size), s2_name(f_idx, g_size))

            c2_nodes = LIFNodes(shape=(1, IMG_SHAPE[0] // 4, IMG_SHAPE[1] // 4,), thresh=-64, traces=True)
            net.add_layer(layer=c2_nodes, name=c2_name(f_idx, g_size))

            max_pool_con = MaxPool2dConnection(s2_nodes, c2_nodes, kernel_size=2, stride=2, decay=0.0)
            net.add_connection(max_pool_con, s2_name(f_idx, g_size), c2_name(f_idx, g_size))

    d1 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d1, name=d1_name())
    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            src_layer = net.layers[c2_name(f_idx, g_size)]
            conn = Connection(
                source=src_layer,
                target=d1,
                w=0.05 + 0.1 * torch.randn(src_layer.n, d1.n),
                update_rule=PostPre
            )
            net.add_connection(conn, c2_name(f_idx, g_size), d1_name())

    d2 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d2, name=d2_name())
    
    d1_d2_conn = Connection(
        source=d1,
        target=d2,
        w=0.05 + 0.1 * torch.randn(d1.n, d2.n),
        update_rule=PostPre
    )
    net.add_connection(d1_d2_conn, d1_name(), d2_name())

    r = LIFNodes(n=len(TARGETS), traces=True)
    net.add_layer(layer=r, name=r_name())

    d2_r_conn = Connection(
        source=d2,
        target=r,
        w=0.05 + 0.05 * torch.randn(d2.n, r.n),
        update_rule=PostPre
    )
    net.add_connection(d2_r_conn, d2_name(), r_name())

    r_rec = Connection(
        source=r,
        target=r,
        w=0.5 * (torch.eye(r.n) - 1),
        decay=0,
    )
    net.add_connection(r_rec, r_name(), r_name())

    net.add_monitor(
        Monitor(net.layers[r_name()], ["s"]),
        "result"
    )
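# The naming helpers and constants used by prepare_network() are not shown in
# the snippet (G_SIZES, G_THETAS, IMG_SHAPE, N_SIZE_FEATURES, DEEP_LAYERS_N
# and TARGETS are assumed module-level configuration). Hypothetical one-line
# helpers consistent with how they are called:

def s1_name(g_size): return f"S1_{g_size}"
def c1_name(g_size): return f"C1_{g_size}"
def s2_name(f_idx, g_size): return f"S2_{f_idx}_{g_size}"
def c2_name(f_idx, g_size): return f"C2_{f_idx}_{g_size}"
def d1_name(): return "D1"
def d2_name(): return "D2"
def r_name(): return "R"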
Example #29
import torch
import matplotlib.pyplot as plt

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import MeanFieldConnection
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_weights

network = Network()

X = Input(n=100)
Y = LIFNodes(n=100)

C = MeanFieldConnection(source=X, target=Y, norm=100.0)

M_X = Monitor(X, state_vars=['s'])
M_Y = Monitor(Y, state_vars=['s', 'v'])
M_C = Monitor(C, state_vars=['w'])

network.add_layer(X, name='X')
network.add_layer(Y, name='Y')
network.add_connection(C, source='X', target='Y')
network.add_monitor(M_X, 'M_X')
network.add_monitor(M_Y, 'M_Y')
network.add_monitor(M_C, 'M_C')

spikes = torch.bernoulli(torch.rand(1000, 100))
inpts = {'X': spikes}

network.run(inpts=inpts, time=1000)
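# The plotting imports above go unused in the fragment; one way to inspect the
# recordings afterwards (a sketch -- Monitor.get returns the full recording,
# and its time-axis convention differs across bindsnet versions):

plot_spikes({'X': M_X.get('s'), 'Y': M_Y.get('s')})
plt.show()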
Example #30
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=1, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        readout = IFNodes(n=n_classes, reset=0, thresh=1)
        network.add_layer(readout, name='Z')

        w = torch.rand(784, n_neurons)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w,
            update_rule=MSTDP, nu=lr, wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=output_layer, target=output_layer, w=w, wmin=-inhib, wmax=0
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

        readout_connection = Connection(
            source=network.layers['Y'], target=readout, w=torch.rand(n_neurons, n_classes), norm=10
        )
        network.add_connection(readout_connection, source='Y', target='Z')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    labels = labels.long()

    spikes = {}
    for layer in set(network.layers) - {'X'}:
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None
    assigns_im = None
    perf_ax = None

    predictions = torch.zeros(update_interval).long()

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]

        # Run the network on the input.
        for j in range(time):
            readout = network.layers['Z'].s

            if readout[labels[i % len(labels)]]:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=1, a_minus=0, a_plus=1)
            else:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=0)

        label = spikes['Z'].get('s').sum(1).argmax()
        predictions[i % update_interval] = label.long()

        if i > 0 and i % update_interval == 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            accuracy = 100 * (predictions == current_labels).float().mean().item()
            print(f'Accuracy over last {update_interval} examples: {accuracy}')

        # Optionally plot various simulation information.
        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            exc_readout_weights = network.connections['Y', 'Z'].w

            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            # square_assignments = get_square_assignments(assignments, n_sqrt)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            weights2_im = plot_weights(exc_readout_weights, im=weights2_im)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')