def test(network, data, labels):
    """Run ``network`` on every batch of ``data`` and print a classification report.

    For each batch the network is simulated for ``RUN_TIME`` steps and the spikes of
    the ``"OUT"`` monitor are stored. Neuron labels are then assigned from those
    recordings and ``labels`` with ``assign_labels``. The "all activity" prediction
    is compared against ``labels`` with ``classification_report``.

    :param network: Network exposing ``run`` and a ``monitors["OUT"]`` spike monitor.
    :param data: Sized iterable of image batches accepted by ``encode_image_batch``.
    :param labels: Sequence of ground-truth labels, one per element of ``data``.
    """
    n_classes = len(SUBJECTS)
    # One (RUN_TIME x n_classes) spike raster per batch.
    recorded = torch.zeros(len(data), RUN_TIME, n_classes)
    targets = torch.from_numpy(np.array(labels))

    for row, batch in enumerate(tqdm(data)):
        network.run(encode_image_batch(batch), time=RUN_TIME)
        out_spikes = network.monitors["OUT"].get("s")
        # Last RUN_TIME entries along axis 0 (presumably time) of the first batch
        # element -- assumes the monitor layout is (time, batch, neurons); confirm.
        recorded[row, :, :] = out_spikes[-RUN_TIME:, 0]

    label_assignment = assign_labels(recorded, targets, n_classes)
    # Only the first element of assign_labels' result (the assignments) is used.
    predicted = all_activity(recorded, label_assignment[0], n_classes)
    print(classification_report(targets, predicted))
def update_curves(
    curves: Dict[str, list],
    labels: torch.Tensor,
    n_classes: int,
    **kwargs
) -> Tuple[Dict[str, list], Dict[str, torch.Tensor]]:
    # language=rst
    """
    Updates accuracy curves for each classification scheme.

    :param curves: Mapping from name of classification scheme (``'all'``,
        ``'proportion'``, ``'ngram'`` or ``'logreg'``) to list of accuracy
        evaluations. Each list is appended to in place.
    :param labels: One-dimensional ``torch.Tensor`` of integer data labels.
    :param n_classes: Number of data categories.
    :param kwargs: Additional keyword arguments for classification scheme evaluation
        functions: ``spike_record`` and ``assignments`` (``'all'``, ``'proportion'``),
        ``proportions`` (``'proportion'``), ``ngram_scores`` and ``n`` (``'ngram'``),
        ``full_spike_record`` and ``logreg`` (``'logreg'``).
    :return: Updated accuracy curves (the same ``curves`` object; each new entry is
        a 0-d tensor holding the accuracy in percent) and a mapping from scheme
        name to its predicted labels.
    :raises NotImplementedError: If ``curves`` contains an unknown scheme name.
    :raises KeyError: If a keyword argument required by a scheme is missing.
    """
    predictions = {}
    for scheme in curves:
        # Branch based on name of classification scheme.
        if scheme == "all":
            spike_record = kwargs["spike_record"]
            assignments = kwargs["assignments"]
            prediction = all_activity(spike_record, assignments, n_classes)
        elif scheme == "proportion":
            spike_record = kwargs["spike_record"]
            assignments = kwargs["assignments"]
            proportions = kwargs["proportions"]
            prediction = proportion_weighting(spike_record, assignments, proportions, n_classes)
        elif scheme == "ngram":
            spike_record = kwargs["spike_record"]
            ngram_scores = kwargs["ngram_scores"]
            n = kwargs["n"]
            prediction = ngram(spike_record, ngram_scores, n_classes, n)
        elif scheme == "logreg":
            full_spike_record = kwargs["full_spike_record"]
            logreg = kwargs["logreg"]
            prediction = logreg_predict(spikes=full_spike_record, logreg=logreg)
        else:
            # Name the offending key so a typo in the curves config is easy to spot.
            raise NotImplementedError(f"Unknown classification scheme: {scheme!r}")

        # Compute accuracy (in percent) with current classification scheme.
        predictions[scheme] = prediction
        accuracy = torch.sum(labels.long() == prediction).float() / len(labels)
        curves[scheme].append(100 * accuracy)

    return curves, predictions
def predict(labeled_batches):
    """Predict a class for every batch in ``labeled_batches``.

    Each batch is run through the network with ``run_sinle_batch``. Its output
    activity is collected via ``get_result_activity``. Neuron labels are assigned
    from the collected activity and the batches' ground-truth labels, and the
    "all activity" scheme produces the predictions.

    :param labeled_batches: Sized iterable of ``(label, img_batch)`` pairs. Each
        ``label`` is assumed to be a numeric class index (int or scalar tensor)
        in ``range(len(TARGETS))`` -- NOTE(review): confirm against callers.
    :return: Predicted labels as returned by ``all_activity``.
    """
    print(f"predicting {len(labeled_batches)} batches...")
    n_samples = len(labeled_batches)
    n_classes = len(TARGETS)
    true_labels = torch.zeros(n_samples)
    activities = torch.zeros(n_samples, ENCODE_WINDOW, n_classes)

    for sample_idx, (label, img_batch) in enumerate(tqdm(labeled_batches)):
        # Keep the ground truth: assign_labels needs it to map neurons to classes.
        true_labels[sample_idx] = label
        run_sinle_batch(img_batch)
        activities[sample_idx, :, :] = get_result_activity()

    assignments, _, _ = assign_labels(activities, true_labels, n_classes)
    pred_labels = all_activity(activities, assignments, n_classes)
    return pred_labels
network.add_monitor(spikes[layer], name="%s_spikes" % layer) # Train the network. print("Begin training.\n") pbar = tqdm(enumerate(dataloader)) for (i, dataPoint) in pbar: if i > n_train: break image = dataPoint["encoded_image"] label = dataPoint["label"] pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train)) if i % update_interval == 0 and i > 0: # Get network predictions. all_activity_pred = all_activity(spike_record, assignments, 10) proportion_pred = proportion_weighting(spike_record, assignments, proportions, 10) # Compute network accuracy according to available classification strategies. accuracy["all"].append( 100 * torch.sum(label.long() == all_activity_pred).item() / update_interval) accuracy["proportion"].append( 100 * torch.sum(label.long() == proportion_pred).item() / update_interval) print( "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" % (accuracy["all"][-1], np.mean( accuracy["all"]), np.max(accuracy["all"])))
def main(args):
    """Train a ``DiehlAndCook2015v2`` network on MNIST and log progress to TensorBoard.

    Every ``args.update_steps`` batches, the spikes recorded since the previous
    update are used to:

    1. score the current neuron-label assignments with the "all activity" and
       "proportion weighting" schemes;
    2. log those accuracies, the mean spike count and the input->excitatory
       weights;
    3. re-assign labels to the excitatory neurons.

    :param args: Parsed command-line namespace. Reads ``update_steps``,
        ``batch_size``, ``gpu``, ``seed``, ``n_workers``, ``n_neurons``,
        ``reduction``, ``inh``, ``dt``, ``theta_plus``, ``time``, ``intensity``,
        ``log_dir``, ``n_epochs``, ``one_step`` and ``plot``. ``update_steps`` and
        ``n_workers`` may be overwritten in place.
    :raises NotImplementedError: If ``args.reduction`` is not ``"sum"``, ``"mean"``
        or ``"max"``.
    """
    # Number of batches to process between two evaluation / label-assignment updates.
    if args.update_steps is None:
        args.update_steps = max(250 // args.batch_size, 1)
    # Number of samples between two updates (= number of rows in ``spike_record``).
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use and seeds the RNGs. ``torch.manual_seed`` seeds the CPU
    # generator and every CUDA device; the CPU generator drives DataLoader
    # shuffling, so it must be seeded even when training on the GPU.
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)

    # Determines number of workers to use.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Side length of the square grid used to display per-neuron weights.
    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # How weight updates are reduced across the samples of a batch.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # 28 x 28 input images
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )
    # NOTE(review): the test split is constructed (and downloaded) but never used below.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1: neuron not yet assigned a class
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up a spike monitor on every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start every run with an empty log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer; pending events are flushed to disk every 60 seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")
        # Labels of the samples currently stored in ``spike_record``.
        labels = []

        # Create a dataloader to iterate and batch data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)
            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:
                # Convert the list of labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=spike_record, assignments=assignments, n_labels=n_classes
                )
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                all_vote_accuracy = torch.mean(
                    (label_tensor.long() == all_activity_pred).float()
                )
                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=all_vote_accuracy,
                    global_step=global_step,
                )
                # Keep a per-update history of the all-vote accuracy. ``accuracy``
                # is not defined in this function -- presumably a module-level list
                # read by ``end_accuracy()``; confirm.
                value = all_vote_accuracy.item()
                accuracy.append(value)
                print("ACCURACY:", value)

                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean((label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")
                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording; permute swaps the first two monitor axes.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            start = (step * args.batch_size) % update_interval
            spike_record[start:start + s.size(0)] = s

            writer.add_scalar(tag="time/simulation", scalar_value=t1, global_step=global_step)

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28
                )
                spikes_ = {layer: spikes[layer].get("s")[:, 0] for layer in spikes}
                spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()

    print(end_accuracy())
inputs = {"X": a} #print(inputs["X"].sum()/32) else: inputs = {"X": batch["encoded_image"]} #print(inputs["X"].sum()/32) if gpu: inputs = {k: v.cuda() for k, v in inputs.items()} if step % update_steps == 0 and step > 0: # Convert the array of labels into a tensor label_tensor = torch.tensor(labels) # Get network predictions. all_activity_pred = all_activity(spikes=spike_record, assignments=assignments, n_labels=n_classes) proportion_pred = proportion_weighting( spikes=spike_record, assignments=assignments, proportions=proportions, n_labels=n_classes, ) # Compute network accuracy according to available classification strategies. accuracy["all"].append( 100 * torch.sum(label_tensor.long() == all_activity_pred).item() / len(label_tensor)) accuracy["proportion"].append( 100 *
voltage_axes = None voltage_ims = None pbar = tqdm(enumerate(dataloader_train)) for (i, datum) in pbar: if i > n_train: break image = datum["encoded_image"] label = datum["label"] pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train)) #Print training accuracy if i % update_interval == 0 and i > 0: # Get network predictions. all_activity_pred = all_activity(spike_record, assignments, num_classes) proportion_pred = proportion_weighting(spike_record, assignments, proportions, num_classes) # Compute network accuracy according to available classification strategies. accuracy["all"].append( 100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval) accuracy["proportion"].append( 100 * torch.sum(labels.long() == proportion_pred).item() / update_interval) print( "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" % (accuracy["all"][-1], np.mean( accuracy["all"]), np.max(accuracy["all"])))
if i % update_interval == 0 and i > 0: # Get a tensor of labels label_tensor = torch.Tensor(labels).to(device) # Get network predictions. if use_mnist: confusion = DataFrame([[0] * n_classes for _ in range(n_classes)]) else: confusion = DataFrame([[0] * n_classes for _ in range(n_classes)], columns=kws, index=kws) all_activity_pred = all_activity(spike_record.to('cpu'), assignments.to('cpu'), n_classes).to(device) proportion_pred = proportion_weighting(spike_record.to('cpu'), assignments.to('cpu'), proportions.to('cpu'), n_classes).to(device) for j in range(len(label_tensor)): true_idx = label_tensor[j].long().item() pred_idx = all_activity_pred[j].item() if use_mnist: confusion[true_idx][pred_idx] += 1 else: confusion[kws[true_idx]][kws[pred_idx]] += 1 # Compute network accuracy accuracy['all'].append(
def main(args):
    """Train ``DiehlAndCook2015v2`` on MNIST with periodic validation, logging to TensorBoard.

    The MNIST training set is randomly split into 59 000 training and 1 000
    validation images. Every ``args.update_steps`` training batches:

    1. learning is disabled and the validation set is run to log validation
       accuracy;
    2. training accuracy is computed from the spikes recorded since the last
       update;
    3. neuron labels are re-assigned from those spikes;
    4. learning is re-enabled.

    :param args: Parsed command-line namespace. Reads ``update_steps``,
        ``batch_size``, ``test_batch_size``, ``gpu``, ``seed``, ``n_workers``,
        ``n_neurons``, ``reduction``, ``inh``, ``dt``, ``theta_plus``, ``time``,
        ``intensity``, ``log_dir``, ``n_epochs``, ``one_step``, ``plot`` and
        ``valid_plot``. ``update_steps`` and ``n_workers`` may be overwritten in place.
    :raises NotImplementedError: If ``args.reduction`` is not ``"sum"``, ``"mean"``
        or ``"max"``.
    """
    # Number of batches between two updates; defaults to roughly 250 samples.
    if args.update_steps is None:
        args.update_steps = max(250 // args.batch_size, 1)
    # Number of samples between two updates (= number of rows in ``spike_record``).
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use and seeds the RNGs. ``torch.manual_seed`` seeds the CPU
    # generator and every CUDA device; the CPU generator drives random_split and
    # DataLoader shuffling, so it must be seeded even when training on the GPU.
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)

    # Determines number of workers to use.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Side length of the square grid used to display per-neuron weights.
    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # How weight updates are reduced across the samples of a batch.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    # Hold out 1 000 training images for validation.
    dataset, valid_dataset = torch.utils.data.random_split(dataset, [59000, 1000])

    # NOTE(review): the test split is constructed (and downloaded) but never used below.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1: neuron not yet assigned a class
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up a spike monitor on every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start every run with an empty log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")
        # Labels of the samples currently stored in ``spike_record``.
        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:
                # Disable learning while evaluating.
                network.train(False)

                # Get validation data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time, args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts, time=args.time, one_step=args.one_step)

                    # Add to spikes recording; permute swaps the first two monitor axes.
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    start = test_step * args.test_batch_size
                    test_spike_record[start:start + s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons), n_sqrt, 28
                        )
                        spikes_ = {layer: spikes[layer].get("s")[:, 0] for layer in spikes}
                        spike_ims, spike_axes = plot_spikes(
                            spikes_, ims=spike_ims, axes=spike_axes
                        )
                        weights_im = plot_weights(square_weights, im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test", scalar_value=t1, global_step=global_step)

                # Validation accuracy with the current assignments.
                test_label_tensor = torch.tensor(test_labels)
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=100
                    * torch.mean((test_label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=100
                    * torch.mean((test_label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Training accuracy on the samples recorded since the last update.
                label_tensor = torch.tensor(labels)
                all_activity_pred = all_activity(
                    spikes=spike_record, assignments=assignments, n_labels=n_classes
                )
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=100
                    * torch.mean((label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=100
                    * torch.mean((label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step", scalar_value=t1, global_step=global_step)

            # Add to spikes recording; permute swaps the first two monitor axes.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            start = (step * args.batch_size) % update_interval
            spike_record[start:start + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28
                )
                spikes_ = {layer: spikes[layer].get("s")[:, 0] for layer in spikes}
                spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
def train(self, config=None):
    """Train ``self.network`` on the configured dataset, one sample at a time.

    Every ``cfg['update_interval']`` samples, the spikes recorded since the previous
    update are used to score the current neuron-label assignments ("all activity"
    and "proportion weighting"). The scores are printed and inserted into
    ``self.recorder``, and neuron labels are re-assigned. ``self.save`` is called
    every 1000 steps.

    :param config: Configuration mapping to use; falls back to ``self.cfg`` when
        ``None``. Reads ``update_interval``, ``time``, ``network.n_neurons``,
        ``epochs``, ``n_workers`` and ``name``.
    :return: ``None``.
    """
    # Use the explicit config when given, otherwise the instance default.
    cfg = self.cfg if config is None else config

    update_interval = cfg['update_interval']
    time = cfg['time']
    n_neurons = cfg['network']['n_neurons']
    dataset, n_classes = self._init_dataset(cfg)

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    assignments = -torch.ones(n_neurons)
    proportions = torch.zeros(n_neurons, n_classes)
    rates = torch.zeros(n_neurons, n_classes)

    # Sequence of accuracy estimates.
    accuracy = {"all": [], "proportion": []}

    # Set up monitors for spikes and voltages (only the spike monitors are read here).
    _exc_voltage_monitor, _inh_voltage_monitor, spikes, _voltages = self._init_network_monitor(
        self.network, cfg)

    print("\nBegin training.\n")
    # Number of samples evaluated so far, used as the recorder's x-axis.
    iteration = 0
    for epoch in range(cfg['epochs']):
        print("Progress: %d / %d" % (epoch, cfg['epochs']))
        labels = []
        start_time = T.time()

        dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=cfg['n_workers'])

        for step, batch in enumerate(tqdm(dataloader)):
            # Get next input sample, viewed as (time, 1, 1, 28, 28).
            inputs = {'X': batch["encoded_image"].view(time, 1, 1, 28, 28)}
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            if step % update_interval == 0 and step > 0:
                # Convert the list of labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=spike_record, assignments=assignments, n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                # Compute network accuracy (percent) for each classification scheme.
                accuracy["all"].append(100 * torch.sum(
                    label_tensor.long() == all_activity_pred).item() / len(label_tensor))
                accuracy["proportion"].append(100 * torch.sum(
                    label_tensor.long() == proportion_pred).item() / len(label_tensor))
                iteration += len(label_tensor)

                print(
                    "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" % (
                        accuracy["all"][-1],
                        np.mean(accuracy["all"]),
                        np.max(accuracy["all"]),
                    ))
                print(
                    "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n" % (
                        accuracy["proportion"][-1],
                        np.mean(accuracy["proportion"]),
                        np.max(accuracy["proportion"]),
                    ))
                self.recorder.insert(
                    (iteration,
                     accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]),
                     accuracy["proportion"][-1], np.mean(accuracy["proportion"]),
                     np.max(accuracy["proportion"])))

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            labels.append(batch["label"])

            # Run the network on the input.
            self.network.run(inputs=inputs, time=time, input_time_dim=1)

            # Add to spikes recording.
            spike_record[step % update_interval] = spikes["Ae"].get("s").squeeze()

            # Reset state variables.
            self.network.reset_state_variables()

            # Periodic checkpoint (also fires at step 0).
            if step % 1000 == 0:
                self.save(cfg=cfg)

        print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, cfg['epochs'], T.time() - start_time))

    self.recorder.write(self.save_dir, cfg['name'])
    print("Training complete.\n")
    return None
perf_ax = None voltage_axes = None voltage_ims = None pbar = tqdm(enumerate(dataloader)) for (i, datum) in pbar: if i > n_train: break image = datum["encoded_image"] label = datum["label"] pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train)) if i % update_interval == 0 and i > 0: # Get network predictions. all_activity_pred = all_activity(spike_record, assignments, 10) proportion_pred = proportion_weighting( spike_record, assignments, proportions, 10 ) # Compute network accuracy according to available classification strategies. accuracy["all"].append( 100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval ) accuracy["proportion"].append( 100 * torch.sum(labels.long() == proportion_pred).item() / update_interval ) print( "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" % (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
def main():
    """Train a ``DiehlAndCook2015`` network on MNIST, checkpointing the best model.

    Hyper-parameters are hard-coded below. Every ``update_interval`` samples, the
    accuracy curves are updated. Whenever any scheme reaches a new best accuracy,
    the network is saved to ``net_output.pt`` and the label parameters to
    ``parameters_output.pt``; neuron labels and n-gram scores are then updated.
    The final network/parameters are written to ``net_final.pt`` /
    ``parameters_final.pt``. With ``load_network = True`` training resumes from
    the ``*_output.pt`` checkpoints.
    """
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from inh. layer to exc. layer
    exc = 22.5
    lr = 1e-2  # learning rate (second entry of ``nu``)
    lr_decay = 0.99  # learning rate decay (only referenced in the commented-out line below)
    time = 350  # simulation time per sample, passed to the poisson encoder
    dt = 1  # timestep
    theta_plus = 0.05  # post-spike threshold increase
    tc_theta_decay = 1e7  # threshold decay (not passed to the network below)
    intensity = 0.25  # multiplier applied to the input; Diehl & Cook use 0.25
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load network from disk
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Checkpoint paths: the load paths must match the save paths used below.
    net_checkpoint = "net_output.pt"
    params_checkpoint = "parameters_output.pt"

    # TRAINING plot outputs.
    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_train.png"
    for directory in ("plots_snn", "plots_snn/weights", "plots_snn/performance",
                      "plots_snn/assaiments"):
        os.makedirs(directory, exist_ok=True)

    assert n_train % update_interval == 0

    # Seed every RNG; torch.manual_seed covers the CPU and all CUDA devices.
    np.random.seed(seed)
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.manual_seed(seed)

    # Build network.
    if load_network:
        network = load(net_checkpoint)
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            inpt_shape=(1, 28, 28),
        )

    if gpu:
        network.to("cuda")

    # Pull dataset.
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False,
                                              num_workers=1)

    # Spike recording: the last ``update_interval`` samples in full, plus per-sample
    # spike counts for every training example.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization.
    if load_network:
        assignments, proportions, rates, ngram_scores = torch.load(params_checkpoint)
    else:
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)
        ngram_scores = {}

    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves}
    best_accuracy = 0

    # Initialize spike monitors.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    # Train.
    train_time = t.time()
    # Labels of the samples currently stored in ``spike_record``.
    current_labels = torch.zeros(update_interval)
    time1 = t.time()
    for j in range(n_epochs):
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break
            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {(t.time()-time1)} s')
                time1 = t.time()

            if i % update_interval == 0 and i > 0:
                # network.connections['X','Y'].update_rule.nu[1] *= lr_decay
                curves, preds = update_curves(curves, current_labels, n_classes,
                                              spike_record=spike_record,
                                              assignments=assignments,
                                              proportions=proportions,
                                              ngram_scores=ngram_scores, n=2)
                print_results(curves)

                for scheme in preds:
                    predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

                # Checkpoint whenever any scheme beats the best accuracy so far.
                if any(x[-1] > best_accuracy for x in curves.values()):
                    print('New best accuracy! Saving network parameters to disk.')
                    network.save(net_checkpoint)
                    # Pass the path so torch.save opens and closes the file itself.
                    torch.save((assignments, proportions, rates, ngram_scores),
                               params_checkpoint)
                    best_accuracy = max(x[-1] for x in curves.values())

                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(spike_record, current_labels, n_classes, 2,
                                                   ngram_scores)

            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}

            # Run the network on the input.
            network.run(inputs=inpts, time=time)

            # Spikes recording.
            spike_record[i % update_interval] = spikes['Ae'].get('s').view(time, n_neurons)
            full_spike_record[i] = spikes['Ae'].get('s').view(time, n_neurons).sum(0).long()

            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w

                square_assignments = get_square_assignments(assignments, n_sqrt)

                assigns_im = plot_assignments(square_assignments, im=assigns_im)

                if i % update_interval == 0:
                    square_weights = get_square_weights(
                        input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im)
                    [weights_im, save_weights_fn] = plot_weights(square_weights, im=weights_im,
                                                                 save=save_weights_fn)

                inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=label,
                                                 axes=inpt_axes, ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
                assigns_im = plot_assignments(square_assignments, im=assigns_im,
                                              save=save_assaiments_fn)
                perf_ax = plot_performance(curves, ax=perf_ax, save=save_performance_fn)

                plt.pause(1e-8)

            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()

            if i % 10 == 0 and i > 0:
                # Predictions for the 10 samples before sample ``i``. When ``i`` is a
                # multiple of ``update_interval`` the buffer has just wrapped, so the
                # window ends at the end of the buffer rather than at index 0.
                window_end = i % update_interval or update_interval
                window = slice(window_end - 10, window_end)
                preds = all_activity(spike_record[window], assignments, n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(f'True value: {current_labels[window].numpy()}')

            i += 1

        print(f'Number of epochs {j + 1}/{n_epochs}')

    print("Training completed. Training took " + str((t.time() - train_time) / 60) + " min.")
    print("Saving network...")
    network.save('net_final.pt')
    torch.save((assignments, proportions, rates, ngram_scores), 'parameters_final.pt')
    print("Network saved.")