def main(args):
    if args.update_steps is None:
        args.update_steps = max(250 // args.batch_size, 1)

    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up monitors for spikes and voltages
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate and batch data
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step}")

            global_step = 60000 * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=spike_record, assignments=assignments, n_labels=n_classes
                )
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()
                    ),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()
                    ),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            spike_record[
                (step * args.batch_size) % update_interval : (step * args.batch_size % update_interval) + s.size(0)
            ] = s

            writer.add_scalar(tag="time/simulation", scalar_value=t1, global_step=global_step)

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28
                )
                spikes_ = {layer: spikes[layer].get("s")[:, 0] for layer in spikes}
                spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
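# The scripts in this file pass `max_without_indices` as a `reduction`, but its
# definition is not included in the excerpt. A minimal sketch consistent with
# that usage: bindsnet reductions take a tensor and a `dim` like torch.sum, but
# torch.max along a dimension returns a (values, indices) pair, so the indices
# must be dropped to match that interface.

import torch


def max_without_indices(x: torch.Tensor, dim: int = 0, keepdim: bool = False) -> torch.Tensor:
    # Return only the max values so the callable has torch.sum's interface.
    return torch.max(x, dim=dim, keepdim=keepdim)[0]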
network.add_layer(input_layer, name="X") network.add_layer(conv_layer, name="Y") network.add_connection(conv_conn, source="X", target="Y") network.add_connection(recurrent_conn, source="Y", target="Y") # Voltage recording for excitatory and inhibitory layers. voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time) network.add_monitor(voltage_monitor, name="output_voltage") if gpu: network.to("cuda") # Load MNIST data. train_dataset = MNIST( PoissonEncoder(time=time, dt=dt), None, "../../data/MNIST", download=True, train=True, transform=transforms.Compose( [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]), ) spikes = {} for layer in set(network.layers): spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) network.add_monitor(spikes[layer], name="%s_spikes" % layer) voltages = {}
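# Once this fragment is wired into a full script and the network has been run,
# the monitors registered above can be read back. A hedged usage sketch
# (`inpts` stands in for whatever encoded batch the surrounding script
# prepares; `time` and `network` are the fragment's own variables):
network.run(inpts=inpts, time=time)
v = voltage_monitor.get("v")       # recorded "v", stacked over simulation time
out_spikes = spikes["Y"].get("s")  # recorded "s" for the output layer
network.reset_()                   # clear state before the next sample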
network.add_connection(FF1a, source="I_a", target="rTNN_1") # (Recurrences) network.add_connection(rTNN_to_buf1, source="rTNN_1", target="BUF_1") network.add_connection(buf1_to_rTNN, source="BUF_1", target="rTNN_1") # End of network creation # Monitors: spikes = {} for l in network.layers: spikes[l] = Monitor(network.layers[l], ["s"], time=num_timesteps) network.add_monitor(spikes[l], name="%s_spikes" % l) # Data and initial encoding: dataset = MNIST( PoissonEncoder(time=num_timesteps, dt=1), None, root=os.path.join("..", "..", "data", "MNIST"), download=True, transform=transforms.Compose( [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]), ) # Create a dataloader to iterate and batch data dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=False)
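# The loop that consumes this dataloader is not shown in the excerpt; a minimal
# sketch of the usual single-sample bindsnet pattern. The layer name "I_a"
# matches the input connection above; `n_train` is a hypothetical sample budget
# added for illustration, and the encoded batch may need reshaping to
# (num_timesteps, *input_shape) depending on the input layer's shape.
n_train = 1000
for i, batch in enumerate(dataloader):
    if i >= n_train:
        break
    network.run(inpts={"I_a": batch["encoded_image"]}, time=num_timesteps)
    out_spikes = {l: spikes[l].get("s") for l in spikes}  # per-layer spike trains
    network.reset_()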
def create_network(self, norm=0.5, competitive_weight=-100.):
    self.norm = norm
    self.competitive_weight = competitive_weight
    self.time_max = 30
    dt = 1
    intensity = 127.5

    self.train_dataset = MNIST(
        PoissonEncoder(time=self.time_max, dt=dt),
        None,
        "MNIST",
        download=False,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
        ),
    )

    # Hyperparameters
    n_filters = 25
    kernel_size = 12
    stride = 4
    padding = 0
    conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
    per_class = int((n_filters * conv_size * conv_size) / 10)
    tc_trace = 20.0  # grid search check
    tc_decay = 20.0
    thresh = -52
    refrac = 5
    wmin = 0
    wmax = 1

    # Network
    self.network = Network(learning=True)
    self.GlobalMonitor = NetworkMonitor(self.network, state_vars=("v", "s", "w"))

    self.input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    self.output_layer = AdaptiveLIFNodes(
        n=n_filters * conv_size * conv_size,
        shape=(n_filters, conv_size, conv_size),
        traces=True,
        thresh=thresh,
        tc_trace=tc_trace,
        tc_decay=tc_decay,
        theta_plus=0.05,
        tc_theta_decay=1e6,
    )

    self.connection_XY = LocalConnection(
        self.input_layer,
        self.output_layer,
        n_filters=n_filters,
        kernel_size=kernel_size,
        stride=stride,
        update_rule=PostPre,
        norm=norm,  # norm constant - check (alternatives tried: 1 / kernel_size ** 2, 0.4 * kernel_size ** 2)
        nu=[1e-4, 1e-2],
        wmin=wmin,
        wmax=wmax,
    )

    # Competitive connections: inhibit the same location in every other filter.
    w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
    for fltr1 in range(n_filters):
        for fltr2 in range(n_filters):
            if fltr1 != fltr2:  # change
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = competitive_weight

    # Flatten to a 2-D (pre, post) weight matrix, as Connection expects.
    w = w.view(n_filters * conv_size * conv_size, n_filters * conv_size * conv_size)
    self.connection_YY = Connection(self.output_layer, self.output_layer, w=w)

    self.network.add_layer(self.input_layer, name="X")
    self.network.add_layer(self.output_layer, name="Y")
    self.network.add_connection(self.connection_XY, source="X", target="Y")
    self.network.add_connection(self.connection_YY, source="Y", target="Y")
    self.network.add_monitor(self.GlobalMonitor, name="Network")

    self.spikes = {}
    for layer in set(self.network.layers):
        self.spikes[layer] = Monitor(
            self.network.layers[layer], state_vars=["s"], time=self.time_max
        )
        self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)

    self.voltages = {}
    for layer in set(self.network.layers) - {"X"}:
        self.voltages[layer] = Monitor(
            self.network.layers[layer], state_vars=["v"], time=self.time_max
        )
        self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)
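# The quadruple loop above fills the lateral-inhibition tensor one entry at a
# time, which gets slow as n_filters and conv_size grow. A vectorized sketch
# with identical semantics (inhibit the same spatial location in every *other*
# filter; the values mirror the hyperparameters above):
import torch

n_filters, conv_size, competitive_weight = 25, 5, -100.0

# 1 where source and target filters differ, 0 on the filter diagonal.
filter_mask = 1.0 - torch.eye(n_filters)
# 1 where the source location (i, j) equals the target location (k, l).
loc_eye = torch.eye(conv_size * conv_size).view(conv_size, conv_size, conv_size, conv_size)
# w[f1, i, j, f2, k, l] = competitive_weight iff f1 != f2 and (i, j) == (k, l).
w = competitive_weight * torch.einsum("ab,ijkl->aijbkl", filter_mask, loc_eye)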
def main(args):
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    conv_size = int((28 - args.kernel_size + 2 * args.padding) / args.stride) + 1

    # Build network.
    network = Network()
    input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    conv_layer = DiehlAndCookNodes(
        n=args.n_filters * conv_size * conv_size,
        shape=(args.n_filters, conv_size, conv_size),
        traces=True,
    )

    conv_conn = Conv2dConnection(
        input_layer,
        conv_layer,
        kernel_size=args.kernel_size,
        stride=args.stride,
        update_rule=PostPre,
        norm=0.4 * args.kernel_size ** 2,
        nu=[0, args.lr],
        reduction=max_without_indices,
        wmax=1.0,
    )

    w = torch.zeros(
        args.n_filters, conv_size, conv_size, args.n_filters, conv_size, conv_size
    )
    for fltr1 in range(args.n_filters):
        for fltr2 in range(args.n_filters):
            if fltr1 != fltr2:
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = -100.0

    w = w.view(
        args.n_filters * conv_size * conv_size, args.n_filters * conv_size * conv_size
    )
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name="X")
    network.add_layer(conv_layer, name="Y")
    network.add_connection(conv_conn, source="X", target="Y")
    network.add_connection(recurrent_conn, source="Y", target="Y")

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers["Y"], ["v"], time=args.time)
    network.add_monitor(voltage_monitor, name="output_voltage")

    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    train_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    voltages = {}
    for layer in set(network.layers) - {"X"}:
        voltages[layer] = Monitor(
            network.layers[layer], state_vars=["v"], time=args.time
        )
        network.add_monitor(voltages[layer], name="%s_voltages" % layer)

    # Train the network.
    print("Begin training.\n")
    start = time()

    weights_im = None

    for epoch in range(args.n_epochs):
        if epoch % args.progress_interval == 0:
            print(
                "Progress: %d / %d (%.4f seconds)"
                % (epoch, args.n_epochs, time() - start)
            )
            start = time()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )

        variance_buffers = {}
        for k in network.connections.keys():
            variance_buffers[k] = {}
            variance_buffers[k]["prev"] = network.connections[k].w.clone()
            variance_buffers[k]["sum"] = torch.zeros_like(
                variance_buffers[k]["prev"], dtype=torch.double
            )
            variance_buffers[k]["sum_squares"] = torch.zeros_like(
                variance_buffers[k]["prev"], dtype=torch.double
            )
            variance_buffers[k]["count"] = 0

        for step, batch in enumerate(tqdm(train_dataloader)):
            # Get next input sample.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time, input_time_dim=0)

            # Manually compute the total update from this run.
            for k in network.connections.keys():
                cur = network.connections[k].w.clone()
                weight_update = (cur - variance_buffers[k]["prev"]).double()
                variance_buffers[k]["sum"] += weight_update
                variance_buffers[k]["sum_squares"] += weight_update * weight_update
                variance_buffers[k]["count"] += 1
                variance_buffers[k]["prev"] = cur

            if step % 1000 == 999:
                process_variance_buffers(variance_buffers)

            # Decay learning rate.
            network.connections["X", "Y"].nu[1] *= 0.99

            # Optionally plot various simulation information.
            if args.plot:
                weights = conv_conn.w
                weights_im = plot_conv2d_weights(weights, im=weights_im)

                plt.pause(1e-8)

            network.reset_()  # Reset state variables.

    print(
        "Progress: %d / %d (%.4f seconds)\n"
        % (args.n_epochs, args.n_epochs, time() - start)
    )
    print("Training complete.\n")

    process_variance_buffers(variance_buffers)
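# process_variance_buffers is called above but not defined in this excerpt. A
# minimal sketch consistent with how the buffers are filled (per-weight running
# mean and variance of the STDP updates, summarized as scalars; the report
# format is an assumption):
import torch


def process_variance_buffers(variance_buffers):
    for k, buf in variance_buffers.items():
        n = buf["count"]
        if n == 0:
            continue
        mean = buf["sum"] / n
        # E[x^2] - E[x]^2: biased per-weight variance of the weight updates.
        var = buf["sum_squares"] / n - mean * mean
        print(
            "connection %s: mean |update| = %.3e, mean update variance = %.3e"
            % (str(k), mean.abs().mean().item(), var.mean().item())
        )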