def test_post_pre(self):
    # Connection test
    network = Network(dt=1.0)
    network.add_layer(Input(n=100, traces=True), name='input')
    network.add_layer(LIFNodes(n=100, traces=True), name='output')
    network.add_connection(
        Connection(source=network.layers['input'], target=network.layers['output'],
                   nu=1e-2, update_rule=PostPre),
        source='input', target='output')
    network.run(
        inpts={'input': torch.bernoulli(torch.rand(250, 100)).byte()}, time=250)

    # Conv2dConnection test
    network = Network(dt=1.0)
    network.add_layer(Input(shape=[1, 1, 10, 10], traces=True), name='input')
    network.add_layer(LIFNodes(shape=[1, 32, 8, 8], traces=True), name='output')
    network.add_connection(Conv2dConnection(
        source=network.layers['input'], target=network.layers['output'],
        kernel_size=3, stride=1, nu=1e-2, update_rule=PostPre),
        source='input', target='output')
    network.run(inpts={
        'input': torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()
    }, time=250)
def create_hmax(network):
    for size in FILTER_SIZES:
        s1 = Input(shape=(FILTER_TYPES, IMAGE_SIZE, IMAGE_SIZE), traces=True)
        network.add_layer(layer=s1, name=get_s1_name(size))
        # network.add_monitor(Monitor(s1, ["s"]), get_s1_name(size))

        c1 = LIFNodes(shape=(FILTER_TYPES, IMAGE_SIZE // 2, IMAGE_SIZE // 2),
                      thresh=-64, traces=True)
        network.add_layer(layer=c1, name=get_c1_name(size))
        # network.add_monitor(Monitor(c1, ["s", "v"]), get_c1_name(size))

        max_pool = MaxPool2dConnection(s1, c1, kernel_size=2, stride=2, decay=0.2)
        network.add_connection(max_pool, get_s1_name(size), get_c1_name(size))

    for feature in FEATURES:
        for size in FILTER_SIZES:
            s2 = LIFNodes(shape=(1, IMAGE_SIZE // 2, IMAGE_SIZE // 2), thresh=-64, traces=True)
            network.add_layer(layer=s2, name=get_s2_name(size, feature))
            # network.add_monitor(Monitor(s2, ["s", "v"]), get_s2_name(size, feature))

            conv = Conv2dConnection(network.layers[get_c1_name(size)], s2, 15, padding=7,
                                    update_rule=PostPre, wmin=0, wmax=1)
            network.add_monitor(
                Monitor(conv, ["w"]), "conv%d%d" % (feature, size)
            )
            network.add_connection(conv, get_c1_name(size), get_s2_name(size, feature))

            c2 = LIFNodes(shape=(1, 1, 1), thresh=-64, traces=True)
            network.add_layer(layer=c2, name=get_c2_name(size, feature))
            # network.add_monitor(Monitor(c2, ["s", "v"]), get_c2_name(size, feature))

            max_pool = MaxPool2dConnection(s2, c2, kernel_size=IMAGE_SIZE // 2, decay=0.0)
            network.add_connection(max_pool, get_s2_name(size, feature), get_c2_name(size, feature))
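# A minimal sketch (not part of the original source) of the constants and naming helpers that
# create_hmax() assumes. FILTER_SIZES, FILTER_TYPES, FEATURES, IMAGE_SIZE and the get_*_name()
# functions below are hypothetical placeholders; their real values and name formats may differ.
FILTER_SIZES = [3, 5, 7, 9]   # assumed S1 filter sizes
FILTER_TYPES = 4              # assumed number of filter orientations per size
FEATURES = range(10)          # assumed number of learned S2 features
IMAGE_SIZE = 32               # assumed input image side length


def get_s1_name(size):
    return 'S1_%d' % size


def get_c1_name(size):
    return 'C1_%d' % size


def get_s2_name(size, feature):
    return 'S2_%d_%d' % (size, feature)


def get_c2_name(size, feature):
    return 'C2_%d_%d' % (size, feature)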
def test_weight_dependent_post_pre(self):
    # Connection test
    network = Network(dt=1.0)
    network.add_layer(Input(n=100, traces=True), name="input")
    network.add_layer(LIFNodes(n=100, traces=True), name="output")
    network.add_connection(
        Connection(source=network.layers["input"], target=network.layers["output"],
                   nu=1e-2, update_rule=WeightDependentPostPre, wmin=-1, wmax=1),
        source="input", target="output",
    )
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, time=250,
    )

    # Conv2dConnection test
    network = Network(dt=1.0)
    network.add_layer(Input(shape=[1, 10, 10], traces=True), name="input")
    network.add_layer(LIFNodes(shape=[32, 8, 8], traces=True), name="output")
    network.add_connection(
        Conv2dConnection(source=network.layers["input"], target=network.layers["output"],
                         kernel_size=3, stride=1, nu=1e-2,
                         update_rule=WeightDependentPostPre, wmin=-1, wmax=1),
        source="input", target="output",
    )
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250,
    )
def test_mstdpet(self):
    # Connection test
    network = Network(dt=1.0)
    network.add_layer(Input(n=100), name="input")
    network.add_layer(LIFNodes(n=100), name="output")
    network.add_connection(
        Connection(source=network.layers["input"], target=network.layers["output"],
                   nu=1e-2, update_rule=MSTDPET),
        source="input", target="output",
    )
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()},
        time=250, reward=1.0,
    )

    # Conv2dConnection test
    network = Network(dt=1.0)
    network.add_layer(Input(shape=[1, 10, 10]), name="input")
    network.add_layer(LIFNodes(shape=[32, 8, 8]), name="output")
    network.add_connection(
        Conv2dConnection(source=network.layers["input"], target=network.layers["output"],
                         kernel_size=3, stride=1, nu=1e-2, update_rule=MSTDPET),
        source="input", target="output",
    )
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()},
        time=250, reward=1.0,
    )
def test_hebbian(self):
    # Connection test
    network = Network(dt=1.0)
    network.add_layer(Input(n=100, traces=True), name="input")
    network.add_layer(LIFNodes(n=100, traces=True), name="output")
    network.add_connection(
        Connection(source=network.layers["input"], target=network.layers["output"],
                   nu=1e-2, update_rule=Hebbian),
        source="input", target="output",
    )
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 100)).byte()}, time=250,
    )

    # Conv2dConnection test
    network = Network(dt=1.0)
    network.add_layer(Input(shape=[1, 10, 10], traces=True), name="input")
    network.add_layer(LIFNodes(shape=[32, 8, 8], traces=True), name="output")
    network.add_connection(
        Conv2dConnection(source=network.layers["input"], target=network.layers["output"],
                         kernel_size=3, stride=1, nu=1e-2, update_rule=Hebbian),
        source="input", target="output",
    )
    # shape is [time, batch, channels, height, width]
    network.run(
        inputs={"input": torch.bernoulli(torch.rand(250, 1, 1, 10, 10)).byte()}, time=250,
    )
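# The learning-rule tests above assume imports along these lines; this is a likely set for a
# recent BindsNET release (exact module paths can vary between versions):
import torch

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection, Conv2dConnection
from bindsnet.learning import PostPre, WeightDependentPostPre, Hebbian, MSTDPET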
# Build network.
network = Network()
input_layer = Input(n=784, shape=(1, 28, 28), traces=True)
conv_layer = DiehlAndCookNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
)
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=0.4 * kernel_size**2,
    nu=[1e-4, 1e-2],
    wmax=1.0,
)

# Competitive inhibition: each filter inhibits every other filter at the same location.
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 != fltr2:
            for i in range(conv_size):
                for j in range(conv_size):
                    w[fltr1, i, j, fltr2, i, j] = -100.0

w = w.view(n_filters * conv_size * conv_size, n_filters * conv_size * conv_size)
def main(seed=0, n_train=60000, n_test=10000, kernel_size=(16,), stride=(4,), n_filters=25,
         padding=0, inhib=100, time=25, lr=1e-3, lr_decay=0.99, dt=1, intensity=1,
         progress_interval=10, update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time, lr, lr_decay,
        dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding, inhib, time,
            lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    if kernel_size == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)

        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0, traces=True,
                                       theta_plus=0.05 * (kernel_size[0] / 20), refrac=0)

        conv_layer2 = LIFNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size), refrac=0)

        conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size,
                                     stride=stride, update_rule=WeightDependentPostPre,
                                     norm=0.05 * total_kernel_size, nu=[0, lr],
                                     wmin=0, wmax=0.25)

        conv_conn2 = Conv2dConnection(input_layer, conv_layer2, w=conv_conn.w,
                                      kernel_size=kernel_size, stride=stride,
                                      update_rule=None, wmax=0.25)

        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    w[f, :, :, f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])

        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1, solver='lbfgs',
                                          max_iter=1000, multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        logreg_coef, logreg_intercept = torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1, solver='lbfgs',
                                          max_iter=1000, multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
                current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves, current_labels, n_classes,
                                          full_spike_record=current_record, logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i], logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w
            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size, time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size, time)
            }

            inpt_axes, inpt_ims = plot_input(image.view(20, 20), _input,
                                             label=labels[i % len(labels)],
                                             ims=inpt_ims, axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(w, im=weights_im,
                                             wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
        current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves, current_labels, n_classes,
                                  full_spike_record=current_record, logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters', 'padding',
                    'inhib', 'time', 'lr', 'lr_decay', 'dt', 'intensity',
                    'update_interval', 'mean_logreg', 'std_logreg'
                ]
                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt', 'intensity',
                    'update_interval', 'mean_logreg', 'std_logreg'
                ]
                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
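# Hypothetical sketch of the logreg_fit() helper referenced in the training loop above; its
# definition is not shown in this excerpt. It presumably refits the warm-started scikit-learn
# model on the per-example spike counts recorded so far. The real implementation may differ.
def logreg_fit(spike_record, labels, logreg):
    # spike_record: [n_examples, n_neurons] tensor of per-example spike counts.
    logreg.fit(spike_record.cpu().numpy(), labels.cpu().numpy())
    return logreg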
# Build network.
network = Network()
input_layer = Input(n=784, shape=(1, 28, 28), traces=True)
conv_layer = DiehlAndCookNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
)
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    update_rule=PostPre,
    norm=0.4 * kernel_size**2,
    nu=[0, lr],
    reduction=torch.mean,
    wmax=1.0,
)

w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 != fltr2:
            for i in range(conv_size):
                for j in range(conv_size):
                    w[fltr1, i, j, fltr2, i, j] = -100.0
def main(args):
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    conv_size = int((28 - args.kernel_size + 2 * args.padding) / args.stride) + 1

    # Build network.
    network = Network()
    input_layer = Input(n=784, shape=(1, 28, 28), traces=True)
    conv_layer = DiehlAndCookNodes(
        n=args.n_filters * conv_size * conv_size,
        shape=(args.n_filters, conv_size, conv_size),
        traces=True,
    )
    conv_conn = Conv2dConnection(
        input_layer,
        conv_layer,
        kernel_size=args.kernel_size,
        stride=args.stride,
        update_rule=PostPre,
        norm=0.4 * args.kernel_size**2,
        nu=[0, args.lr],
        reduction=max_without_indices,
        wmax=1.0,
    )

    w = torch.zeros(args.n_filters, conv_size, conv_size, args.n_filters, conv_size, conv_size)
    for fltr1 in range(args.n_filters):
        for fltr2 in range(args.n_filters):
            if fltr1 != fltr2:
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = -100.0

    w = w.view(args.n_filters * conv_size * conv_size, args.n_filters * conv_size * conv_size)
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name="X")
    network.add_layer(conv_layer, name="Y")
    network.add_connection(conv_conn, source="X", target="Y")
    network.add_connection(recurrent_conn, source="Y", target="Y")

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers["Y"], ["v"], time=args.time)
    network.add_monitor(voltage_monitor, name="output_voltage")

    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    train_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    voltages = {}
    for layer in set(network.layers) - {"X"}:
        voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=args.time)
        network.add_monitor(voltages[layer], name="%s_voltages" % layer)

    # Train the network.
    print("Begin training.\n")
    start = time()

    weights_im = None
    for epoch in range(args.n_epochs):
        if epoch % args.progress_interval == 0:
            print("Progress: %d / %d (%.4f seconds)" % (epoch, args.n_epochs, time() - start))
            start = time()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(tqdm(train_dataloader)):
            # Get next input sample.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time, input_time_dim=0)

            # Decay learning rate.
            network.connections["X", "Y"].nu[1] *= 0.99

            # Optionally plot various simulation information.
            if args.plot:
                weights = conv_conn.w
                weights_im = plot_conv2d_weights(weights, im=weights_im)

                plt.pause(1e-8)

            network.reset_()  # Reset state variables.

    print("Progress: %d / %d (%.4f seconds)\n" % (args.n_epochs, args.n_epochs, time() - start))
    print("Training complete.\n")
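# Hypothetical sketch of the max_without_indices reduction passed to Conv2dConnection above;
# its definition is not part of this excerpt. BindsNET's `reduction` argument collapses
# per-sample weight updates over the batch dimension, and torch.max returns a (values, indices)
# pair, so a thin wrapper that keeps only the values is one plausible implementation:
def max_without_indices(tensor, dim=0, keepdim=False):
    return torch.max(tensor, dim=dim, keepdim=keepdim)[0]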
input_layer = Input(n=784, shape=(1, 1, 28, 28), traces=True)

conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size),
                               thresh=-64.0, traces=True,
                               theta_plus=0.05, refrac=0)

conv_layer_ = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                shape=(1, n_filters, *conv_size),
                                refrac=0, traces=True, theta_decay=5e-1)

conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size,
                             stride=stride, update_rule=PostPre,
                             norm=1.0 * int(np.sqrt(total_kernel_size)),
                             nu=(0, 1e-2), wmax=2.0)

conv_conn_ = Conv2dConnection(input_layer, conv_layer_, w=conv_conn.w,
                              kernel_size=kernel_size, stride=stride,
                              update_rule=None, nu=(0, 1e-2), wmax=2.0)

conv_layer2 = DiehlAndCookNodes(n=n_filters * total_conv_size2,
                                shape=(1, n_filters, *conv_size2),
                                thresh=-64.0, traces=True,
def prepare_network():
    global net
    net = Network()

    for g_size in G_SIZES:
        s1_g_size = Input(shape=(len(G_THETAS), IMG_SHAPE[0], IMG_SHAPE[1]), traces=True)
        net.add_layer(layer=s1_g_size, name=s1_name(g_size))

        c1_g_size = LIFNodes(shape=(len(G_THETAS), IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2),
                             thresh=-64, traces=True)
        net.add_layer(layer=c1_g_size, name=c1_name(g_size))

        max_pool_con = MaxPool2dConnection(s1_g_size, c1_g_size, kernel_size=2, stride=2, decay=0.0)
        net.add_connection(max_pool_con, s1_name(g_size), c1_name(g_size))

    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            s2_nodes = LIFNodes(shape=(1, IMG_SHAPE[0] // 2, IMG_SHAPE[1] // 2),
                                traces=True, tc_decay=50.0, thresh=-55, trace_scale=0.2)
            net.add_layer(layer=s2_nodes, name=s2_name(f_idx, g_size))

            conv_con = Conv2dConnection(net.layers[c1_name(g_size)], s2_nodes, 5, padding=2,
                                        nu=[0.0006, 0.008], update_rule=PostPre, wmin=0, wmax=1)
            net.add_connection(conv_con, c1_name(g_size), s2_name(f_idx, g_size))

            c2_nodes = LIFNodes(shape=(1, IMG_SHAPE[0] // 4, IMG_SHAPE[1] // 4),
                                thresh=-64, traces=True)
            net.add_layer(layer=c2_nodes, name=c2_name(f_idx, g_size))

            max_pool_con = MaxPool2dConnection(s2_nodes, c2_nodes, kernel_size=2, stride=2, decay=0.0)
            net.add_connection(max_pool_con, s2_name(f_idx, g_size), c2_name(f_idx, g_size))

    d1 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d1, name=d1_name())

    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            src_layer = net.layers[c2_name(f_idx, g_size)]
            conn = Connection(
                source=src_layer,
                target=d1,
                w=0.05 + 0.1 * torch.randn(src_layer.n, d1.n),
                update_rule=PostPre
            )
            net.add_connection(conn, c2_name(f_idx, g_size), d1_name())

    d2 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d2, name=d2_name())

    d1_d2_conn = Connection(
        source=d1,
        target=d2,
        w=0.05 + 0.1 * torch.randn(d1.n, d2.n),
        update_rule=PostPre
    )
    net.add_connection(d1_d2_conn, d1_name(), d2_name())

    r = LIFNodes(n=len(TARGETS), traces=True)
    net.add_layer(layer=r, name="R")

    d2_r_conn = Connection(
        source=d2,
        target=r,
        w=0.05 + 0.05 * torch.randn(d2.n, r.n),
        update_rule=PostPre
    )
    net.add_connection(d2_r_conn, d2_name(), r_name())

    r_rec = Connection(
        source=r,
        target=r,
        w=0.5 * (torch.eye(r.n) - 1),
        decay=0,
    )
    net.add_connection(r_rec, r_name(), r_name())

    net.add_monitor(
        Monitor(net.layers[r_name()], ["s"]), "result"
    )
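# A minimal sketch (not part of the original excerpt) of the constants and naming helpers that
# prepare_network() relies on. Every value and name format below is an assumption; only
# r_name() == "R" is implied by the excerpt itself (the layer is added under the literal name "R").
G_SIZES = [7, 11, 15]          # assumed Gabor kernel sizes
G_THETAS = [0, 45, 90, 135]    # assumed Gabor orientations (degrees)
IMG_SHAPE = (28, 28)           # assumed input image shape
N_SIZE_FEATURES = 8            # assumed number of S2 features per size
DEEP_LAYERS_N = 100            # assumed size of the dense layers d1/d2
TARGETS = list(range(10))      # assumed output classes


def s1_name(g_size): return 'S1_%d' % g_size
def c1_name(g_size): return 'C1_%d' % g_size
def s2_name(f_idx, g_size): return 'S2_%d_%d' % (f_idx, g_size)
def c2_name(f_idx, g_size): return 'C2_%d_%d' % (f_idx, g_size)
def d1_name(): return 'D1'
def d2_name(): return 'D2'
def r_name(): return 'R'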
# Build network.
network = Network()
input_layer = Input(n=32 * 32 * 3, shape=(1, 3, 32, 32), traces=True)

conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size),
                               thresh=-64.0, traces=True,
                               theta_plus=0.05, refrac=0)

conv_layer2 = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                shape=(1, n_filters, *conv_size), refrac=0)

conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size,
                             stride=stride, update_rule=Hebbian,
                             norm=0.5 * int(np.sqrt(total_kernel_size)),
                             nu=(1e-3, 1e-3), wmax=2.0)

conv_conn2 = Conv2dConnection(input_layer, conv_layer2, w=conv_conn.w,
                              kernel_size=kernel_size, stride=stride,
                              update_rule=None, nu=(0, 1e-3), wmax=2.0)

w = torch.ones(1, n_filters, conv_size[0], conv_size[1],
               1, n_filters, conv_size[0], conv_size[1])
for f in range(n_filters):
def main(args):
    # Random seed.
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Build network.
    network = bindsnet.network.Network(dt=args.dt, batch_size=args.batch_size)

    # Layers.
    input_layer = Input(shape=(1, 28, 28), traces=True)
    conv1_layer = LIFNodes(shape=(20, 24, 24), traces=True)
    pool1_layer = PassThroughNodes(shape=(20, 12, 12), traces=True)
    conv2_layer = LIFNodes(shape=(50, 8, 8), traces=True)
    pool2_layer = PassThroughNodes(shape=(50, 4, 4), traces=True)
    dense_layer = LIFNodes(shape=(200,), traces=True)
    output_layer = LIFNodes(shape=(10,), traces=True)
    network.add_layer(input_layer, name="I")
    network.add_layer(conv1_layer, name="C1")
    network.add_layer(pool1_layer, name="P1")
    network.add_layer(conv2_layer, name="C2")
    network.add_layer(pool2_layer, name="P2")
    network.add_layer(dense_layer, name="D")
    network.add_layer(output_layer, name="O")

    # Connections.
    conv1_connection = Conv2dConnection(
        source=input_layer,
        target=conv1_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool1_connection = SpatialPooling2dConnection(source=conv1_layer, target=pool1_layer,
                                                  kernel_size=2, stride=2)
    conv2_connection = Conv2dConnection(
        source=pool1_layer,
        target=conv2_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool2_connection = SpatialPooling2dConnection(source=conv2_layer, target=pool2_layer,
                                                  kernel_size=2, stride=2)
    dense_connection = Connection(
        source=pool2_layer,
        target=dense_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )
    output_connection = Connection(
        source=dense_layer,
        target=output_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )
    network.add_connection(connection=conv1_connection, source="I", target="C1")
    network.add_connection(connection=pool1_connection, source="C1", target="P1")
    network.add_connection(connection=conv2_connection, source="P1", target="C2")
    network.add_connection(connection=pool2_connection, source="C2", target="P2")
    network.add_connection(connection=dense_connection, source="P2", target="D")
    network.add_connection(connection=output_connection, source="D", target="O")

    # Monitors.
    for name, layer in network.layers.items():
        monitor = Monitor(obj=layer, state_vars=("s",), time=args.time)
        network.add_monitor(monitor=monitor, name=name)

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(bindsnet.ROOT_DIR, "data", "MNIST"),
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    spike_ims = None
    spike_axes = None
    conv1_weights_im = None
    conv2_weights_im = None
    dense_weights_im = None
    output_weights_im = None

    for epoch in range(args.n_epochs):
        # Create a dataloader to iterate over dataset.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(tqdm(dataloader)):
            # Prep next input batch.
            inpts = {"I": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time)

            # Plot simulation data.
            if args.plot:
                spikes = {}
                for name, monitor in network.monitors.items():
                    spikes[name] = monitor.get("s")[:, 0].view(args.time, -1)

                spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes)
                conv1_weights_im = plot_conv2d_weights(conv1_connection.w, im=conv1_weights_im,
                                                       wmin=-1.0, wmax=1.0)
                conv2_weights_im = plot_conv2d_weights(conv2_connection.w, im=conv2_weights_im,
                                                       wmin=-1.0, wmax=1.0)
                dense_weights_im = plot_weights(dense_connection.w, im=dense_weights_im,
                                                wmin=-1.0, wmax=1.0)
                output_weights_im = plot_weights(output_connection.w, im=output_weights_im,
                                                 wmin=-1.0, wmax=1.0)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
def main(seed=0, n_train=60000, n_test=10000, kernel_size=16, stride=4, n_filters=25,
         padding=0, inhib=500, lr=0.01, lr_decay=0.99, time=50, dt=1, intensity=1,
         progress_interval=10, update_interval=250, train=True, plot=False, gpu=False):

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if not train:
        update_interval = n_test

    if kernel_size == 32:
        conv_size = 1
    else:
        conv_size = int((32 - kernel_size + 2 * padding) / stride) + 1

    per_class = int((n_filters * conv_size * conv_size) / 10)

    # Build network.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)

    conv_layer = DiehlAndCookNodes(n=n_filters * conv_size * conv_size,
                                   shape=(1, n_filters, conv_size, conv_size), traces=True)

    conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size,
                                 stride=stride, update_rule=PostPre,
                                 norm=0.4 * kernel_size**2, nu=[0, lr], wmin=0, wmax=1)

    w = -inhib * torch.ones(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
    for f in range(n_filters):
        for i in range(conv_size):
            for j in range(conv_size):
                w[f, i, j, f, i, j] = 0

    w = w.view(n_filters * conv_size**2, n_filters * conv_size**2)
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images.mean(-1)

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    voltages = {}
    for layer in set(network.layers) - {'X'}:
        voltages[layer] = Monitor(network.layers[layer], state_vars=['v'], time=time)
        network.add_monitor(voltages[layer], name='%s_voltages' % layer)

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    voltage_ims = None
    voltage_axes = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for i in range(n_train):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_train, t() - start))
            start = t()

        if train and i > 0:
            network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample.
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Optionally plot various simulation information.
        if plot:
            # inpt = inpts['X'].view(time, 1024).sum(0).view(32, 32)
            weights1 = conv_conn.w
            _spikes = {
                'X': spikes['X'].get('s').view(32**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_size**2, time)
            }
            _voltages = {
                'Y': voltages['Y'].get('v').view(n_filters * conv_size**2, time)
            }

            # inpt_axes, inpt_ims = plot_input(
            #     images[i].view(32, 32), inpt, label=labels[i], axes=inpt_axes, ims=inpt_ims
            # )
            # voltage_ims, voltage_axes = plot_voltages(_voltages, ims=voltage_ims, axes=voltage_axes)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(weights1, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' % (n_train, n_train, t() - start))
    print('Training complete.\n')