def validate(loader, model, encoder, criterion, opt):
    """Evaluate ``model`` on every sample yielded by ``loader``.

    For each ``(image, label)`` pair the image is spike-encoded, presented to
    the model one time step at a time, and the output of ``model.fc2`` is
    recorded at every step. The recorded outputs are then scored with
    ``accuracy`` and ``criterion``.

    Args:
        loader: Iterable yielding ``(image, label)`` pairs. Each image is
            squeezed on dim 0, so presumably one sample per batch is
            expected -- TODO confirm against the caller.
        model: Callable network exposing ``fc2.o`` and ``reset_variables``.
        encoder: Callable turning an image into a spike train whose first
            dimension is indexed by time step.
        criterion: Loss callable accepting ``r``, ``z``, ``label`` and
            ``epsilon`` keyword arguments.
        opt: Options object providing ``beta``, ``num_classes``,
            ``time_interval`` and ``epsilon``.

    Returns:
        A ``(mean_loss, accuracy)`` tuple averaged over the number of samples
        seen. When ``loader`` yields nothing, ``(0.0, 0.0)`` is returned
        instead of dividing by zero.
    """
    num_images = 0
    total_loss = 0.0
    num_corrects = 0

    for image, label in loader:
        image = image.squeeze(dim=0).cuda()
        label = label.squeeze().cuda()

        # Flatten everything but the leading (time) dimension.
        spiked_image = encoder(image)
        spiked_image = spiked_image.view(spiked_image.size(0), -1)

        spiked_label = label_encoder(label, opt.beta, opt.num_classes, opt.time_interval)

        outputs = []
        for t in range(opt.time_interval):
            model(spiked_image[t])
            # Clone so the later reset cannot touch the recorded output.
            outputs.append(model.fc2.o.clone())

        # Clear the model's per-sample variables; ``w=False`` is passed so the
        # weights are presumably left untouched -- verify in the model class.
        model.reset_variables(w=False)

        # Stack once and share the result between both metrics.
        responses = torch.stack(outputs)

        num_images += 1
        num_corrects += accuracy(r=responses, label=label)
        total_loss += criterion(r=responses, z=spiked_label, label=label, epsilon=opt.epsilon)

    if num_images == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return 0.0, 0.0

    return total_loss / num_images, float(num_corrects) / num_images
def app(opt):
    """Assign a class label to every excitatory neuron of a pretrained model.

    The MNIST training split is streamed through a ``DiehlAndCook2015Infer``
    model initialised from the checkpoint at ``opt.pretrained``. For each
    class, the firing rate of ``model.exc.s`` is averaged over all images of
    that class; every neuron is then assigned the class for which its average
    rate is highest. The resulting tensor is printed and saved to
    ``opt.assigned`` under the key ``'assigned_label'``.

    Args:
        opt: Options object providing ``data``, ``batch_size``, ``pretrained``,
            ``neurons``, ``time_interval``, ``num_classes`` and ``assigned``.
            Each batch is reshaped with ``view(1, 28, 28)``, so
            ``opt.batch_size`` has to be 1.
    """
    print(opt)

    # Pixel values from ToTensor() are scaled by 32 * 4 before encoding.
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            opt.data,
            train=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(lambda x: x * 32 * 4)])),
        batch_size=opt.batch_size,
        shuffle=False)

    # Load pretrained weights and thresholds
    state_dict = torch.load(opt.pretrained)['model_state_dict']
    trained_w = state_dict['xe.w']
    trained_th = state_dict['exc.theta']

    model = n3ml.model.DiehlAndCook2015Infer(neurons=opt.neurons)
    model.xe.w.copy_(trained_w)
    model.exc.theta.copy_(trained_th)

    encoder = n3ml.encoder.PoissonEncoder(opt.time_interval)

    # Per-class accumulated firing rates and per-class sample counts.
    total_rates = torch.zeros((opt.num_classes, opt.neurons))
    total_labels = torch.zeros(opt.num_classes)

    start = time.time()

    for step, (image, label) in enumerate(train_loader):
        model.init_param()

        image = image.view(1, 28, 28)

        spiked_image = encoder(image)
        spiked_image = spiked_image.view(opt.time_interval, -1)
        spiked_image = spiked_image.cuda()

        spike_train = []

        for t in range(opt.time_interval):
            model.run({'inp': spiked_image[t]})
            # Keep a detached CPU copy of the spikes for this time step.
            spike_train.append(model.exc.s.clone().detach().cpu())

        spike_train = torch.stack(spike_train)

        # Mean firing rate of each neuron over the presentation window.
        total_rates[label] += torch.sum(spike_train, dim=0) / opt.time_interval
        total_labels[label] += 1

        if (step+1) % 1000 == 0:
            end = time.time()
            print("elapsed times: {} - number of images: {}".format(end-start, step+1))

    # A class that never appeared has a zero count and an all-zero rate row.
    # Clamping the divisor to 1 keeps that row at 0 instead of producing
    # 0/0 = NaN, which would otherwise corrupt the argmax below.
    total_avg_rates = total_rates / total_labels.clamp(min=1).unsqueeze(dim=1)

    # For every neuron (column), pick the class (row) with the highest rate.
    assigned_label = torch.argmax(total_avg_rates, dim=0)
    print(assigned_label)

    torch.save({'assigned_label': assigned_label}, opt.assigned)
def train(loader, model, encoder, optimizer, criterion, opt) -> None:
    """Run one pass over ``loader``, updating the model at every time step.

    Each sample is spike-encoded and fed to ``model`` step by step. After each
    step the most recent spikes of the input, ``fc1`` and ``fc2`` are handed to
    ``optimizer.step`` together with the encoded target for that step. Running
    loss and accuracy are pushed to a ``Plot`` instance every 30 samples.

    Args:
        loader: Iterable yielding ``(image, label)`` pairs.
        model: Callable network exposing ``fc1.o``, ``fc2.o`` and
            ``reset_variables``.
        encoder: Callable turning an image into a spike train whose first
            dimension is indexed by time step.
        optimizer: Object whose ``step(spike_buffer, target, label)`` applies
            the learning rule.
        criterion: Loss callable accepting ``r``, ``z``, ``label`` and
            ``epsilon`` keyword arguments.
        opt: Options object providing ``beta``, ``num_classes``,
            ``time_interval`` and ``epsilon``.
    """
    plotter = Plot()

    seen = 0
    running_loss = 0.0
    hits = 0

    loss_history = []
    acc_history = []

    for image, label in loader:
        # Batch processing isn't supported yet, so drop the batch dimension.
        image = image.squeeze(dim=0)
        label = label.squeeze()

        encoded = encoder(image)
        spiked_image = encoded.view(encoded.size(0), -1)

        spiked_label = label_encoder(label, opt.beta, opt.num_classes, opt.time_interval)

        # Sliding windows of recent spikes, one per layer.
        spike_buffer = {layer: [] for layer in ('inp', 'fc1', 'fc2')}
        loss_buffer = []

        print()
        print("label: {}".format(label))

        for t in range(opt.time_interval):
            model(spiked_image[t])

            spike_buffer['inp'].append(spiked_image[t].clone())
            spike_buffer['fc1'].append(model.fc1.o.clone())
            spike_buffer['fc2'].append(model.fc2.o.clone())

            loss_buffer.append(model.fc2.o.clone())

            # Keep only the five most recent entries per layer.
            # TODO: express the window length 5 in terms of epsilon.
            for history in spike_buffer.values():
                if len(history) > 5:
                    del history[0]

            print(model.fc2.o.numpy())

            optimizer.step(spike_buffer, spiked_label[t], label)

        model.reset_variables(w=False)

        seen += 1
        hits += accuracy(r=torch.stack(loss_buffer), label=label)
        running_loss += criterion(r=torch.stack(loss_buffer), z=spiked_label, label=label, epsilon=opt.epsilon)

        # Refresh the plot only on every 30th sample.
        if seen % 30 != 0:
            continue

        loss_history.append(running_loss / seen)
        acc_history.append(float(hits) / seen)

        plotter.update(y1=np.array(acc_history), y2=np.array(loss_history))