Exemplo n.º 1
0
    def test_post_pre(self):
        """Smoke-test the PostPre rule on a dense and a 2D-convolutional connection.

        Each case builds a fresh ``Network(dt=1.0)`` with traced input and
        traced LIF output layers, attaches a connection trained by ``PostPre``
        with ``nu=1e-2``, and simulates 250 steps of random Bernoulli input.
        The test passes if both simulations run without raising.
        """
        n_steps = 250

        # Dense Connection: 100 inputs -> 100 LIF neurons.
        dense_net = Network(dt=1.0)
        dense_net.add_layer(Input(n=100, traces=True), name='input')
        dense_net.add_layer(LIFNodes(n=100, traces=True), name='output')
        dense_conn = Connection(source=dense_net.layers['input'],
                                target=dense_net.layers['output'],
                                nu=1e-2,
                                update_rule=PostPre)
        dense_net.add_connection(dense_conn, source='input', target='output')
        dense_spikes = torch.bernoulli(torch.rand(n_steps, 100)).byte()
        dense_net.run(inpts={'input': dense_spikes}, time=n_steps)

        # Conv2dConnection: 3x3 kernel, stride 1, input shape [1, 1, 10, 10]
        # -> output shape [1, 32, 8, 8].
        conv_net = Network(dt=1.0)
        conv_net.add_layer(Input(shape=[1, 1, 10, 10], traces=True),
                           name='input')
        conv_net.add_layer(LIFNodes(shape=[1, 32, 8, 8], traces=True),
                           name='output')
        conv_conn = Conv2dConnection(source=conv_net.layers['input'],
                                     target=conv_net.layers['output'],
                                     kernel_size=3,
                                     stride=1,
                                     nu=1e-2,
                                     update_rule=PostPre)
        conv_net.add_connection(conv_conn, source='input', target='output')
        conv_spikes = torch.bernoulli(torch.rand(n_steps, 1, 1, 10, 10)).byte()
        conv_net.run(inpts={'input': conv_spikes}, time=n_steps)
Exemplo n.º 2
0
def create_hmax(network):
    """Add HMAX-style S1/C1/S2/C2 layers and connections to ``network``.

    For every filter size, an ``Input`` S1 layer of shape
    (FILTER_TYPES, IMAGE_SIZE, IMAGE_SIZE) is max-pooled (kernel 2, stride 2,
    decay 0.2) into a LIF C1 layer. Then, for every (feature, size) pair, a
    15x15 ``PostPre`` convolution (padding 7, weights in [0, 1]) maps that C1
    layer onto a single-channel LIF S2 layer. A weight monitor is registered
    on that convolution. The S2 map is then max-pooled over its whole extent
    into a (1, 1, 1) LIF C2 layer.

    Args:
        network: the ``Network`` to extend in place. Nothing is returned.
    """
    half = IMAGE_SIZE // 2

    # Stage 1: S1 -> C1, one pair of layers per filter size.
    for size in FILTER_SIZES:
        s1_name = get_s1_name(size)
        c1_name = get_c1_name(size)

        s1 = Input(shape=(FILTER_TYPES, IMAGE_SIZE, IMAGE_SIZE), traces=True)
        network.add_layer(layer=s1, name=s1_name)

        c1 = LIFNodes(shape=(FILTER_TYPES, half, half), thresh=-64, traces=True)
        network.add_layer(layer=c1, name=c1_name)

        network.add_connection(
            MaxPool2dConnection(s1, c1, kernel_size=2, stride=2, decay=0.2),
            s1_name,
            c1_name,
        )

    # Stage 2: C1 -> S2 -> C2, one chain per (feature, filter size).
    for feature in FEATURES:
        for size in FILTER_SIZES:
            c1_name = get_c1_name(size)
            s2_name = get_s2_name(size, feature)
            c2_name = get_c2_name(size, feature)

            s2 = LIFNodes(shape=(1, half, half), thresh=-64, traces=True)
            network.add_layer(layer=s2, name=s2_name)

            conv = Conv2dConnection(network.layers[c1_name], s2, 15, padding=7,
                                    update_rule=PostPre, wmin=0, wmax=1)

            # NOTE(review): "conv%d%d" joins the two numbers with no separator,
            # so distinct pairs such as (1, 11) and (11, 1) would produce the
            # same monitor name.
            network.add_monitor(Monitor(conv, ["w"]), "conv%d%d" % (feature, size))
            network.add_connection(conv, c1_name, s2_name)

            c2 = LIFNodes(shape=(1, 1, 1), thresh=-64, traces=True)
            network.add_layer(layer=c2, name=c2_name)

            network.add_connection(
                MaxPool2dConnection(s2, c2, kernel_size=half, decay=0.0),
                s2_name,
                c2_name,
            )
Exemplo n.º 3
0
    def test_weight_dependent_post_pre(self):
        """Run WeightDependentPostPre on a dense and a conv connection.

        A local helper builds a fresh ``Network(dt=1.0)`` with traced
        ``Input`` and ``LIFNodes`` layers. It connects them with the given
        connection class (``nu=1e-2``, ``wmin=-1``, ``wmax=1``) and simulates
        250 steps of Bernoulli input. The test passes if neither simulation
        raises.
        """
        def simulate(conn_cls, input_kwargs, output_kwargs, sample_shape,
                     **conn_kwargs):
            net = Network(dt=1.0)
            net.add_layer(Input(traces=True, **input_kwargs), name="input")
            net.add_layer(LIFNodes(traces=True, **output_kwargs), name="output")
            net.add_connection(
                conn_cls(
                    source=net.layers["input"],
                    target=net.layers["output"],
                    nu=1e-2,
                    update_rule=WeightDependentPostPre,
                    wmin=-1,
                    wmax=1,
                    **conn_kwargs,
                ),
                source="input",
                target="output",
            )
            spikes = torch.bernoulli(torch.rand(250, *sample_shape)).byte()
            net.run(inputs={"input": spikes}, time=250)

        # Dense Connection: 100 -> 100.
        simulate(Connection, {"n": 100}, {"n": 100}, (100,))

        # Conv2dConnection: 3x3 kernel, stride 1, [1, 10, 10] -> [32, 8, 8];
        # the input tensor adds time and batch dims in front.
        simulate(
            Conv2dConnection,
            {"shape": [1, 10, 10]},
            {"shape": [32, 8, 8]},
            (1, 1, 10, 10),
            kernel_size=3,
            stride=1,
        )
Exemplo n.º 4
0
    def test_mstdpet(self):
        """Check that MSTDPET trains dense and conv connections without error.

        Two independent networks are simulated for 250 steps of random
        Bernoulli input with ``reward=1.0``:

        * a 100 -> 100 dense ``Connection``;
        * a 3x3, stride-1 ``Conv2dConnection`` from an input of shape
          [1, 10, 10] to a LIF layer of shape [32, 8, 8].

        Neither network enables spike traces.
        """
        duration = 250
        learning = {"nu": 1e-2, "update_rule": MSTDPET}

        # --- dense connection ---
        dense = Network(dt=1.0)
        dense.add_layer(Input(n=100), name="input")
        dense.add_layer(LIFNodes(n=100), name="output")
        dense.add_connection(
            Connection(
                source=dense.layers["input"],
                target=dense.layers["output"],
                **learning,
            ),
            source="input",
            target="output",
        )
        dense_probs = torch.rand(duration, 100)
        dense.run(
            inputs={"input": torch.bernoulli(dense_probs).byte()},
            time=duration,
            reward=1.0,
        )

        # --- convolutional connection ---
        conv = Network(dt=1.0)
        conv.add_layer(Input(shape=[1, 10, 10]), name="input")
        conv.add_layer(LIFNodes(shape=[32, 8, 8]), name="output")
        conv.add_connection(
            Conv2dConnection(
                source=conv.layers["input"],
                target=conv.layers["output"],
                kernel_size=3,
                stride=1,
                **learning,
            ),
            source="input",
            target="output",
        )
        conv_probs = torch.rand(duration, 1, 1, 10, 10)
        conv.run(
            inputs={"input": torch.bernoulli(conv_probs).byte()},
            time=duration,
            reward=1.0,
        )
Exemplo n.º 5
0
    def test_hebbian(self):
        """Run the Hebbian rule on a dense and a 2D-convolutional connection.

        Each case builds a fresh ``Network(dt=1.0)`` with traced ``Input`` and
        ``LIFNodes`` layers, links them with a ``nu=1e-2`` Hebbian connection,
        and simulates 250 steps of Bernoulli input. The test passes if both
        simulations complete without raising.
        """
        cases = (
            # (connection class, input kwargs, output kwargs, sample dims,
            #  extra connection kwargs)
            (Connection, {"n": 100}, {"n": 100}, (100,), {}),
            (
                Conv2dConnection,
                {"shape": [1, 10, 10]},
                {"shape": [32, 8, 8]},
                # shape is [time, batch, channels, height, width] with time
                # prepended below.
                (1, 1, 10, 10),
                {"kernel_size": 3, "stride": 1},
            ),
        )

        for conn_cls, in_kw, out_kw, dims, extra in cases:
            network = Network(dt=1.0)
            network.add_layer(Input(traces=True, **in_kw), name="input")
            network.add_layer(LIFNodes(traces=True, **out_kw), name="output")
            network.add_connection(
                conn_cls(
                    source=network.layers["input"],
                    target=network.layers["output"],
                    nu=1e-2,
                    update_rule=Hebbian,
                    **extra,
                ),
                source="input",
                target="output",
            )
            spikes = torch.bernoulli(torch.rand(250, *dims)).byte()
            network.run(inputs={"input": spikes}, time=250)
Exemplo n.º 6
0
# Build network.
network = Network()
# Single-channel 28x28 input with spike traces; n=784 equals 1 * 28 * 28.
input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

# Convolutional output layer: one conv_size x conv_size map per filter.
# conv_size, n_filters, kernel_size and stride are defined elsewhere in the
# original script (not visible in this excerpt).
conv_layer = DiehlAndCookNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
)

# Learning convolution from the input to the conv layer: PostPre rule with
# nu=[1e-4, 1e-2], norm=0.4 * kernel_size**2 and wmax=1.0.
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=0.4 * kernel_size**2,
    nu=[1e-4, 1e-2],
    wmax=1.0,
)

# Recurrent weights indexed [source filter, row, col, target filter, row, col]:
# units of *different* filters at the same (row, col) get -100.0; every other
# entry stays 0.
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size,
                conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 != fltr2:
            for i in range(conv_size):
                for j in range(conv_size):
                    w[fltr1, i, j, fltr2, i, j] = -100.0

# NOTE(review): the statement below is truncated in this excerpt (its second
# argument and closing parenthesis are missing); see the complete version of
# this construction in the source project.
w = w.view(n_filters * conv_size * conv_size,
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=(16, ),
         stride=(4, ),
         n_filters=25,
         padding=0,
         inhib=100,
         time=25,
         lr=1e-3,
         lr_decay=0.99,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """Train or test a convolutional spiking network on cropped MNIST.

    Images are cropped to 20x20 (4 pixels removed from every border) and
    Bernoulli-encoded for ``time`` steps.

    In training mode a new network is built:

    * a ``Conv2dConnection`` from input ``'X'`` to ``DiehlAndCookNodes``
      layer ``'Y'``, learning with ``WeightDependentPostPre``;
    * a non-learning convolution into LIF layer ``'Y_'`` that reuses the same
      weight tensor;
    * a recurrent inhibitory connection on ``'Y'``.

    In test mode the network saved under the model name is loaded from
    ``params_path``. Its X->Y rule is replaced by ``NoOp`` and the
    ``theta_decay``/``theta_plus`` of ``'Y'`` are set to 0.

    Spike counts of ``'Y_'`` feed a logistic-regression readout. Every
    ``update_interval`` examples the accuracy curves are updated, printed and
    saved to ``curves_path``. In training mode the network and readout are
    also saved on a new best accuracy, and the readout is refit. At the end,
    summary results are appended to ``train.csv``/``test.csv`` in
    ``results_path`` and confusion matrices are saved to ``confusion_path``.

    Args:
        seed: seed for the NumPy and PyTorch RNGs.
        n_train: number of training examples (also part of the model name).
        n_test: number of test examples (used when ``train`` is False).
        kernel_size: convolution kernel size. A scalar or one-element
            sequence is expanded to a square kernel.
        stride: convolution stride. A scalar or one-element sequence is
            expanded to the same stride on both axes.
        n_filters: number of convolution filters.
        padding: only recorded in the model/result names; it is not passed
            to the convolution.
        inhib: magnitude of the recurrent inhibition among the units of each
            filter.
        time: simulation duration per example (passed to ``network.run``,
            the encoder and the monitors).
        lr: initial ``nu[1]`` of the X->Y connection.
        lr_decay: factor applied to that rate every ``update_interval``
            training examples.
        dt: time step passed to the Bernoulli encoder.
        intensity: multiplier applied to the pixel values.
        progress_interval: print progress every this many examples.
        update_interval: evaluate/refit every this many examples. It must
            divide both ``n_train`` and ``n_test``.
        plot: plot inputs, spikes and weights while running.
        train: train a new network if True, otherwise test a saved one.
        gpu: use CUDA float tensors as the default tensor type.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    def _as_pair(value):
        """Expand a scalar or one-element size to two entries."""
        try:
            length = len(value)
        except TypeError:  # plain scalar such as 16
            return (value, value)
        return (value[0], value[0]) if length == 1 else value

    # Without this expansion the defaults (16,) and (4,) raise IndexError in
    # the conv_size computation below. Two-element values are returned as-is,
    # so model names built from str(kernel_size) do not change for them.
    kernel_size = _as_pair(kernel_size)
    stride = _as_pair(stride)

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time,
        lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding,
            inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    # torch.manual_seed seeds the CPU generator and all CUDA devices, so
    # CPU-side randomness is reproducible on the GPU path too.
    torch.manual_seed(seed)
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    # Compare as lists: kernel_size may be a tuple, and a tuple never equals
    # a list.
    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)
        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0,
                                       traces=True,
                                       theta_plus=0.05 * (kernel_size[0] / 20),
                                       refrac=0)
        # LIF layer 'Y_' is driven by a non-learning connection that reuses
        # conv_conn's weight tensor. Its spikes are what the readout sees.
        conv_layer2 = LIFNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size),
                               refrac=0)
        conv_conn = Conv2dConnection(input_layer,
                                     conv_layer,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     update_rule=WeightDependentPostPre,
                                     norm=0.05 * total_kernel_size,
                                     nu=[0, lr],
                                     wmin=0,
                                     wmax=0.25)
        conv_conn2 = Conv2dConnection(input_layer,
                                      conv_layer2,
                                      w=conv_conn.w,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      update_rule=None,
                                      wmax=0.25)

        # Recurrent weights indexed
        # [source filter, row, col, target filter, row, col]. Start at -inhib
        # everywhere, then clear the entries between *different* filters so
        # only units of the same filter inhibit one another.
        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    # The target filter is axis 3. Slicing the column axis
                    # with :f2 would instead wipe out (for small conv maps)
                    # the whole tensor.
                    w[f, :, :, f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Record the membrane voltages of layer 'Y'.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        # Test mode: load the trained network and stop it from adapting.
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Crop 4 pixels from every border to get the 20x20 input.
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        # Passing the path lets torch open and close the file itself.
        logreg_coef, logreg_intercept = torch.load(path)
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i %
                                        len(labels)]
                current_record = full_spike_record[i % len(labels) -
                                                   update_interval:i %
                                                   len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          full_spike_record=current_record,
                                          logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       os.path.join(curves_path, f))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_),
                               path)
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i],
                                          logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Keep the readout connection on the (possibly updated) learned weights.
        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(
            time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w

            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size,
                                               time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size,
                                                 time)
            }

            inpt_axes, inpt_ims = plot_input(image.view(20, 20),
                                             _input,
                                             label=labels[i % len(labels)],
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(
                w, im=weights_im, wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i %
                                len(labels)]
        current_record = full_spike_record[i % len(labels) -
                                           update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  full_spike_record=current_record,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_), path)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               os.path.join(curves_path, f))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                    'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_logreg',
                    'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Trim or tile the labels so they line up with the predictions.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
Exemplo n.º 8
0
# Build network: a traced 1x28x28 input feeding a Diehl & Cook conv layer.
network = Network()
input_layer = Input(shape=(1, 28, 28), n=784, traces=True)

conv_layer = DiehlAndCookNodes(
    shape=(n_filters, conv_size, conv_size),
    n=n_filters * conv_size * conv_size,
    traces=True,
)

# Learning convolution: PostPre with nu=[0, lr], torch.mean as the reduction,
# norm=0.4 * kernel_size**2 and wmax=1.0.
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    update_rule=PostPre,
    nu=[0, lr],
    reduction=torch.mean,
    norm=0.4 * kernel_size**2,
    kernel_size=kernel_size,
    stride=stride,
    wmax=1.0,
)

# Recurrent weights indexed [source filter, row, col, target filter, row, col]:
# units of different filters at the same (row, col) get -100.0, everything
# else stays 0.
w = torch.zeros(n_filters, conv_size, conv_size,
                n_filters, conv_size, conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 == fltr2:
            continue
        for i in range(conv_size):
            for j in range(conv_size):
                w[fltr1, i, j, fltr2, i, j] = -100.0
Exemplo n.º 9
0
def main(args):
    """Train a convolutional Diehl & Cook network on MNIST with PostPre.

    The network has:

    * input layer ``"X"`` (1x28x28);
    * a ``DiehlAndCookNodes`` layer ``"Y"`` with ``args.n_filters`` maps;
    * a learning X->Y ``Conv2dConnection``;
    * a recurrent Y->Y ``Connection`` in which units of different filters at
      the same position inhibit each other with weight -100.

    It is trained for ``args.n_epochs`` epochs on Poisson-encoded MNIST. The
    X->Y post-synaptic learning rate is multiplied by 0.99 after every batch.

    Args:
        args: namespace providing ``seed``, ``gpu``, ``kernel_size``,
            ``padding``, ``stride``, ``n_filters``, ``lr``, ``time``, ``dt``,
            ``intensity``, ``n_epochs``, ``progress_interval``,
            ``batch_size`` and ``plot``.
    """
    # torch.manual_seed seeds the CPU generator and every CUDA device, so
    # CPU-side randomness (such as DataLoader shuffling) is reproducible on
    # the GPU path too.
    torch.manual_seed(args.seed)
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)

    conv_size = int(
        (28 - args.kernel_size + 2 * args.padding) / args.stride) + 1

    # Build network.
    network = Network()
    input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    conv_layer = DiehlAndCookNodes(
        n=args.n_filters * conv_size * conv_size,
        shape=(args.n_filters, conv_size, conv_size),
        traces=True,
    )

    # padding must be passed here too: conv_size above already accounts for
    # it, so omitting it makes the output shape disagree with conv_layer
    # whenever args.padding > 0.
    conv_conn = Conv2dConnection(
        input_layer,
        conv_layer,
        kernel_size=args.kernel_size,
        stride=args.stride,
        padding=args.padding,
        update_rule=PostPre,
        norm=0.4 * args.kernel_size**2,
        nu=[0, args.lr],
        reduction=max_without_indices,
        wmax=1.0,
    )

    # Recurrent weights indexed [source filter, row, col, target filter,
    # row, col]: different filters at the same position get -100.0.
    w = torch.zeros(args.n_filters, conv_size, conv_size, args.n_filters,
                    conv_size, conv_size)
    for fltr1 in range(args.n_filters):
        for fltr2 in range(args.n_filters):
            if fltr1 != fltr2:
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = -100.0

    w = w.view(args.n_filters * conv_size * conv_size,
               args.n_filters * conv_size * conv_size)
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name="X")
    network.add_layer(conv_layer, name="Y")
    network.add_connection(conv_conn, source="X", target="Y")
    network.add_connection(recurrent_conn, source="Y", target="Y")

    # Record the membrane voltages of layer "Y".
    voltage_monitor = Monitor(network.layers["Y"], ["v"], time=args.time)
    network.add_monitor(voltage_monitor, name="output_voltage")

    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    train_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    voltages = {}
    for layer in set(network.layers) - {"X"}:
        voltages[layer] = Monitor(network.layers[layer],
                                  state_vars=["v"],
                                  time=args.time)
        network.add_monitor(voltages[layer], name="%s_voltages" % layer)

    # Train the network.
    print("Begin training.\n")
    start = time()

    weights_im = None

    for epoch in range(args.n_epochs):
        if epoch % args.progress_interval == 0:
            print("Progress: %d / %d (%.4f seconds)" %
                  (epoch, args.n_epochs, time() - start))
            start = time()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(tqdm(train_dataloader)):
            # Get next input sample.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time, input_time_dim=0)

            # Decay learning rate. Scale the update rule's own nu. In some
            # BindsNET versions the rule keeps a copy of the nu passed to
            # the connection, so scaling connection.nu alone would leave the
            # learning rate unchanged. Where they are the same object, this
            # behaves exactly as before.
            network.connections["X", "Y"].update_rule.nu[1] *= 0.99

            # Optionally plot various simulation information.
            if args.plot:
                weights = conv_conn.w
                weights_im = plot_conv2d_weights(weights, im=weights_im)

                plt.pause(1e-8)

            network.reset_()  # Reset state variables.

    print("Progress: %d / %d (%.4f seconds)\n" %
          (args.n_epochs, args.n_epochs, time() - start))
    print("Training complete.\n")
Exemplo n.º 10
0
# Traced input of shape (1, 1, 28, 28), i.e. 784 units.
input_layer = Input(shape=(1, 1, 28, 28), n=784, traces=True)

# Learning conv layer: Diehl & Cook nodes with thresh=-64.0,
# theta_plus=0.05 and refrac=0.
conv_layer = DiehlAndCookNodes(
    n=n_filters * total_conv_size,
    shape=(1, n_filters, *conv_size),
    traces=True,
    thresh=-64.0,
    refrac=0,
    theta_plus=0.05,
)

# Second conv layer of the same shape (refrac=0, theta_decay=5e-1). It is
# driven by conv_conn_ below, which is handed conv_conn's weight tensor.
conv_layer_ = DiehlAndCookNodes(
    n=n_filters * total_conv_size,
    shape=(1, n_filters, *conv_size),
    traces=True,
    refrac=0,
    theta_decay=5e-1,
)

# Learning convolution (PostPre, nu=(0, 1e-2), wmax=2.0).
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    nu=(0, 1e-2),
    norm=1.0 * int(np.sqrt(total_kernel_size)),
    wmax=2.0,
)

# Non-learning convolution (update_rule=None) reusing conv_conn.w.
conv_conn_ = Conv2dConnection(
    input_layer,
    conv_layer_,
    w=conv_conn.w,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=None,
    nu=(0, 1e-2),
    wmax=2.0,
)
conv_layer2 = DiehlAndCookNodes(n=n_filters * total_conv_size2,
                                shape=(1, n_filters, *conv_size2),
                                thresh=-64.0,
                                traces=True,
Exemplo n.º 11
0
def prepare_network():
    """Build the HMAX-style spiking network and store it in the global ``net``.

    Structure (every layer is registered under the name returned by the
    matching ``*_name`` helper):

    * S1 -> C1: for each Gabor size in ``G_SIZES``, an ``Input`` layer with one
      channel per orientation in ``G_THETAS`` is 2x2 max-pooled into a
      half-resolution LIF layer.
    * C1 -> S2 -> C2: for each feature index in ``range(N_SIZE_FEATURES)`` and
      each Gabor size, a learnable 5x5 convolution (``PostPre``, padding 2)
      feeds a single-channel S2 layer, which is 2x2 max-pooled into C2.
    * Every C2 layer -> D1 -> D2 -> R through dense ``PostPre`` connections
      whose weights start at ``mean + std * randn``.
    * R has a fixed recurrent connection: 0 on the diagonal and -0.5 between
      distinct units. Its spikes are recorded by a monitor named ``"result"``.

    The network is assigned to the module-level ``net``. Nothing is returned.
    """
    global net

    net = Network()

    # S1 -> C1 (one pathway per Gabor size).
    for g_size in G_SIZES:
        s1_g_size = Input(
            shape=(len(G_THETAS), IMG_SHAPE[0], IMG_SHAPE[1]), traces=True)
        net.add_layer(layer=s1_g_size, name=s1_name(g_size))

        c1_g_size = LIFNodes(
            shape=(len(G_THETAS), IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2),
            thresh=-64,
            traces=True)
        net.add_layer(layer=c1_g_size, name=c1_name(g_size))

        max_pool_con = MaxPool2dConnection(
            s1_g_size, c1_g_size, kernel_size=2, stride=2, decay=0.0)
        net.add_connection(max_pool_con, s1_name(g_size), c1_name(g_size))

    # C1 -> S2 (learned 5x5 conv) -> C2 (2x2 max pool), per feature and size.
    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            s2_nodes = LIFNodes(
                shape=(1, IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2),
                traces=True,
                tc_decay=50.0,
                thresh=-55,
                trace_scale=0.2)
            net.add_layer(layer=s2_nodes, name=s2_name(f_idx, g_size))

            # padding=2 with a 5x5 kernel keeps the C1 spatial size.
            conv_con = Conv2dConnection(
                net.layers[c1_name(g_size)], s2_nodes, 5, padding=2,
                nu=[0.0006, 0.008], update_rule=PostPre, wmin=0, wmax=1)
            net.add_connection(
                conv_con, c1_name(g_size), s2_name(f_idx, g_size))

            c2_nodes = LIFNodes(
                shape=(1, IMG_SHAPE[0] // 4, IMG_SHAPE[1] // 4),
                thresh=-64,
                traces=True)
            net.add_layer(layer=c2_nodes, name=c2_name(f_idx, g_size))

            max_pool_con = MaxPool2dConnection(
                s2_nodes, c2_nodes, kernel_size=2, stride=2, decay=0.0)
            net.add_connection(
                max_pool_con, s2_name(f_idx, g_size), c2_name(f_idx, g_size))

    # All C2 layers -> D1 (dense, weights initialised to 0.05 + 0.1 * N(0, 1)).
    d1 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d1, name=d1_name())
    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            src_layer = net.layers[c2_name(f_idx, g_size)]
            conn = Connection(
                source=src_layer,
                target=d1,
                w=0.05 + 0.1 * torch.randn(src_layer.n, d1.n),
                update_rule=PostPre
            )
            net.add_connection(conn, c2_name(f_idx, g_size), d1_name())

    # D1 -> D2.
    d2 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d2, name=d2_name())

    d1_d2_conn = Connection(
        source=d1,
        target=d2,
        w=0.05 + 0.1 * torch.randn(d1.n, d2.n),
        update_rule=PostPre
    )
    net.add_connection(d1_d2_conn, d1_name(), d2_name())

    # D2 -> R (one readout neuron per target).
    r = LIFNodes(n=len(TARGETS), traces=True)
    # Register R under r_name(): the connections and the monitor below look it
    # up by that name, so a hard-coded literal could leave them pointing at a
    # layer name that does not exist.
    net.add_layer(layer=r, name=r_name())

    d2_r_conn = Connection(
        source=d2,
        target=r,
        w=0.05 + 0.05 * torch.randn(d2.n, r.n),
        update_rule=PostPre
    )
    net.add_connection(d2_r_conn, d2_name(), r_name())

    # Fixed lateral inhibition inside R: eye - 1 is 0 on the diagonal and -1
    # elsewhere, so each unit inhibits the others by 0.5 and not itself.
    r_rec = Connection(
        source=r,
        target=r,
        w=0.5 * (torch.eye(r.n) - 1),
        decay=0,
    )
    net.add_connection(r_rec, r_name(), r_name())

    net.add_monitor(
        Monitor(net.layers[r_name()], ["s"]),
        "result"
    )
Exemplo n.º 12
0
# Build network: a 32x32 RGB input feeding two convolutional populations
# that use the same kernels (one learns with Hebbian, one does not learn).
network = Network()
input_layer = Input(traces=True, n=32 * 32 * 3, shape=(1, 3, 32, 32))

# Learning population with an adaptive threshold (theta_plus).
conv_layer = DiehlAndCookNodes(
    n=n_filters * total_conv_size,
    shape=(1, n_filters, *conv_size),
    traces=True,
    thresh=-64.0,
    refrac=0,
    theta_plus=0.05,
)

# Second population with the same geometry and default threshold / traces.
conv_layer2 = DiehlAndCookNodes(
    n=n_filters * total_conv_size,
    shape=(1, n_filters, *conv_size),
    refrac=0,
)

# Learned convolution (Hebbian on both pre- and post-synaptic terms).
conv_conn = Conv2dConnection(
    source=input_layer,
    target=conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=Hebbian,
    nu=(1e-3, 1e-3),
    norm=0.5 * int(np.sqrt(total_kernel_size)),
    wmax=2.0,
)

# Non-learning convolution handed conv_conn.w directly, presumably so it
# tracks the learned kernels. NOTE(review): confirm `w` is not copied.
conv_conn2 = Conv2dConnection(
    source=input_layer,
    target=conv_layer2,
    w=conv_conn.w,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=None,
    nu=(0, 1e-3),
    wmax=2.0,
)

# 8-D all-ones tensor: the (1, n_filters, rows, cols) layer shape, twice.
w = torch.ones(*(2 * (1, n_filters, conv_size[0], conv_size[1])))
for f in range(n_filters):
Exemplo n.º 13
0
def main(args):
    """Train a convolutional spiking network on MNIST with ``WeightDependentPost``.

    Topology (shapes exclude the batch dimension)::

        I  (1, 28, 28)   --5x5 conv-->     C1 (20, 24, 24)
        C1               --2x2 pooling-->  P1 (20, 12, 12)
        P1               --5x5 conv-->     C2 (50, 8, 8)
        C2               --2x2 pooling-->  P2 (50, 4, 4)
        P2               --dense-->        D  (200,)
        D                --dense-->        O  (10,)

    The conv and dense connections learn with ``WeightDependentPost``
    (``nu=(0.0, args.nu)``) and have weights bounded to [-1, 1]. The pooling
    connections are built without an explicit update rule.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``gpu``,
            ``n_workers``, ``dt``, ``batch_size``, ``nu``, ``time``,
            ``intensity``, ``n_epochs`` and ``plot``. When ``n_workers`` is
            -1 it is overwritten in place with ``4 * torch.cuda.device_count()``
            if ``gpu`` is truthy, otherwise with 0.
    """
    # Random seed. Seed the CPU generator unconditionally: DataLoader shuffling
    # draws from it even when the network itself runs on the GPU.
    torch.manual_seed(args.seed)
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Determines number of workers.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Build network.
    network = bindsnet.network.Network(dt=args.dt, batch_size=args.batch_size)

    # Layers.
    input_layer = Input(shape=(1, 28, 28), traces=True)
    conv1_layer = LIFNodes(shape=(20, 24, 24), traces=True)
    pool1_layer = PassThroughNodes(shape=(20, 12, 12), traces=True)
    conv2_layer = LIFNodes(shape=(50, 8, 8), traces=True)
    pool2_layer = PassThroughNodes(shape=(50, 4, 4), traces=True)
    dense_layer = LIFNodes(shape=(200, ), traces=True)
    output_layer = LIFNodes(shape=(10, ), traces=True)

    network.add_layer(input_layer, name="I")
    network.add_layer(conv1_layer, name="C1")
    network.add_layer(pool1_layer, name="P1")
    network.add_layer(conv2_layer, name="C2")
    network.add_layer(pool2_layer, name="P2")
    network.add_layer(dense_layer, name="D")
    network.add_layer(output_layer, name="O")

    # Connections.
    conv1_connection = Conv2dConnection(
        source=input_layer,
        target=conv1_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool1_connection = SpatialPooling2dConnection(source=conv1_layer,
                                                  target=pool1_layer,
                                                  kernel_size=2,
                                                  stride=2)
    conv2_connection = Conv2dConnection(
        source=pool1_layer,
        target=conv2_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool2_connection = SpatialPooling2dConnection(source=conv2_layer,
                                                  target=pool2_layer,
                                                  kernel_size=2,
                                                  stride=2)
    dense_connection = Connection(
        source=pool2_layer,
        target=dense_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )
    output_connection = Connection(
        source=dense_layer,
        target=output_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )

    network.add_connection(connection=conv1_connection,
                           source="I",
                           target="C1")
    network.add_connection(connection=pool1_connection,
                           source="C1",
                           target="P1")
    network.add_connection(connection=conv2_connection,
                           source="P1",
                           target="C2")
    network.add_connection(connection=pool2_connection,
                           source="C2",
                           target="P2")
    network.add_connection(connection=dense_connection,
                           source="P2",
                           target="D")
    network.add_connection(connection=output_connection,
                           source="D",
                           target="O")

    # Monitors: record spikes of every layer over one presentation window.
    for name, layer in network.layers.items():
        monitor = Monitor(obj=layer, state_vars=("s", ), time=args.time)
        network.add_monitor(monitor=monitor, name=name)

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(bindsnet.ROOT_DIR, "data", "MNIST"),
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # One loader serves every epoch: with shuffle=True, a new permutation is
    # drawn each time the loader is iterated.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=args.gpu,
    )

    spike_ims = None
    spike_axes = None
    conv1_weights_im = None
    conv2_weights_im = None
    dense_weights_im = None
    output_weights_im = None

    for epoch in range(args.n_epochs):
        for batch in tqdm(dataloader):
            # Prep next input batch.
            inpts = {"I": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time)

            # Plot simulation data (first sample of the batch only).
            if args.plot:
                spikes = {}
                for name, monitor in network.monitors.items():
                    spikes[name] = monitor.get("s")[:, 0].view(args.time, -1)

                spike_ims, spike_axes = plot_spikes(spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)

                conv1_weights_im = plot_conv2d_weights(conv1_connection.w,
                                                       im=conv1_weights_im,
                                                       wmin=-1.0,
                                                       wmax=1.0)
                conv2_weights_im = plot_conv2d_weights(conv2_connection.w,
                                                       im=conv2_weights_im,
                                                       wmin=-1.0,
                                                       wmax=1.0)
                dense_weights_im = plot_weights(dense_connection.w,
                                                im=dense_weights_im,
                                                wmin=-1.0,
                                                wmax=1.0)
                output_weights_im = plot_weights(output_connection.w,
                                                 im=output_weights_im,
                                                 wmin=-1.0,
                                                 wmax=1.0)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
Exemplo n.º 14
0
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=16,
         stride=4,
         n_filters=25,
         padding=0,
         inhib=500,
         lr=0.01,
         lr_decay=0.99,
         time=50,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         train=True,
         plot=False,
         gpu=False):
    """Train a single-layer convolutional Diehl & Cook-style SNN on greyscale CIFAR-10.

    Each 32x32 image is averaged over its last axis, scaled by ``intensity``,
    Poisson-encoded and fed through a ``Conv2dConnection`` (``PostPre``
    learning, post-synaptic rate ``lr``) into a ``DiehlAndCookNodes`` layer
    ``Y``. ``Y`` has a fixed recurrent connection in which every neuron
    inhibits every other neuron with weight ``-inhib``. Self-connections are 0.

    Args:
        seed: RNG seed.
        n_train: Number of samples to present. Capped at the number of
            available images.
        n_test: Currently unused; kept for interface compatibility.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        n_filters: Number of convolution filters.
        padding: Zero padding applied by the convolution.
        inhib: Magnitude of the recurrent inhibition inside ``Y``.
        lr: Post-synaptic learning rate of the convolution.
        lr_decay: Factor applied to that learning rate every
            ``progress_interval`` samples while training.
        time: Simulation length per sample.
        dt: Simulation time step used for Poisson encoding.
        intensity: Input scaling factor.
        progress_interval: Samples between progress reports.
        update_interval: Currently unused; kept for interface compatibility.
        train: Use the training split if True, otherwise the test split.
        plot: Plot spikes and conv weights after each sample.
        gpu: Make CUDA float tensors the default tensor type.
    """
    # Always seed the CPU generator (torch.manual_seed also seeds CUDA); with
    # only the CUDA seed, CPU-side randomness would be unseeded in GPU mode.
    torch.manual_seed(seed)
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    # Output side length of a conv over a 32x32 input. The same padding is
    # passed to Conv2dConnection below so the layer shape matches.
    conv_size = (32 - kernel_size + 2 * padding) // stride + 1

    # Build network.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)

    conv_layer = DiehlAndCookNodes(n=n_filters * conv_size * conv_size,
                                   shape=(1, n_filters, conv_size, conv_size),
                                   traces=True)

    conv_conn = Conv2dConnection(input_layer,
                                 conv_layer,
                                 kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 update_rule=PostPre,
                                 norm=0.4 * kernel_size**2,
                                 nu=[0, lr],
                                 wmin=0,
                                 wmax=1)

    # All-to-all inhibition among conv neurons: -inhib everywhere except the
    # diagonal (a neuron does not inhibit itself).
    n_conv = n_filters * conv_size**2
    w = -inhib * torch.ones(n_conv, n_conv)
    diag = torch.arange(n_conv)
    w[diag, diag] = 0
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'),
                      download=True)

    if train:
        images, _ = dataset.get_train()
    else:
        images, _ = dataset.get_test()

    # Scale, then average the last axis (presumably the colour channels) to
    # obtain greyscale images.
    images *= intensity
    images = images.mean(-1)

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    # data_loader is a finite generator over `images`, so never ask for more
    # samples than there are images; otherwise next() raises StopIteration.
    n_samples = min(n_train, len(images))

    spikes = {}
    for layer in network.layers:
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    spike_ims = None
    spike_axes = None
    weights_im = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for i in range(n_samples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_samples, t() - start))
            start = t()

            if train and i > 0:
                network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample, shaped (time, 1, 1, 32, 32).
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Optionally plot spikes and convolution weights.
        if plot:
            weights1 = conv_conn.w
            _spikes = {
                'X': spikes['X'].get('s').view(32**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_size**2, time)
            }

            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(weights1, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' %
          (n_samples, n_samples, t() - start))
    print('Training complete.\n')