Example #1
def validate(loader, model, encoder, criterion, opt):
    num_images = 0
    total_loss = 0.0
    num_corrects = 0

    for image, label in loader:
        # Squeeze the batch dimension; batch processing isn't supported yet
        image = image.squeeze(dim=0).cuda()
        label = label.squeeze().cuda()

        # Encode the image and the label into spike trains
        spiked_image = encoder(image)
        spiked_image = spiked_image.view(spiked_image.size(0), -1)

        spiked_label = label_encoder(label, opt.beta, opt.num_classes, opt.time_interval)

        loss_buffer = []

        # Simulate the network over the time interval and collect the output
        # spikes of the last layer
        for t in range(opt.time_interval):
            model(spiked_image[t])

            loss_buffer.append(model.fc2.o.clone())

        # Reset membrane potentials and spikes, but keep the weights
        model.reset_variables(w=False)

        num_images += 1
        num_corrects += accuracy(r=torch.stack(loss_buffer), label=label)
        total_loss += criterion(r=torch.stack(loss_buffer), z=spiked_label, label=label, epsilon=opt.epsilon)

    return total_loss/num_images, float(num_corrects)/num_images
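
The label_encoder helper used above is not included in this listing. Below is a minimal sketch of a function with the same signature, assuming beta is the firing probability of the target class at each timestep; the actual helper used by these examples may differ.

import torch

def label_encoder(label, beta, num_classes, time_interval):
    # Hypothetical sketch: a Bernoulli spike train over `time_interval` steps in
    # which the target class fires with probability `beta` and the other
    # classes stay silent.
    rates = torch.zeros(num_classes)
    rates[label] = beta
    return torch.bernoulli(rates.repeat(time_interval, 1))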
Example #2
def app(opt):
    print(opt)

    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            opt.data,
            train=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(), torchvision.transforms.Lambda(lambda x: x * 32 * 4)])),
        batch_size=opt.batch_size,
        shuffle=False)

    # Load pretrained weights and thresholds
    state_dict = torch.load(opt.pretrained)['model_state_dict']
    trained_w = state_dict['xe.w']
    trained_th = state_dict['exc.theta']

    model = n3ml.model.DiehlAndCook2015Infer(neurons=opt.neurons)
    model.xe.w.copy_(trained_w)
    model.exc.theta.copy_(trained_th)

    encoder = n3ml.encoder.PoissonEncoder(opt.time_interval)

    # Accumulated firing rates of the excitatory neurons and sample counts per class
    total_rates = torch.zeros((opt.num_classes, opt.neurons))
    total_labels = torch.zeros(opt.num_classes)

    start = time.time()

    for step, (image, label) in enumerate(train_loader):
        model.init_param()

        image = image.view(1, 28, 28)

        spiked_image = encoder(image)
        spiked_image = spiked_image.view(opt.time_interval, -1)
        spiked_image = spiked_image.cuda()

        spike_train = []

        # Run the network over the time interval and record the excitatory spikes
        for t in range(opt.time_interval):
            model.run({'inp': spiked_image[t]})

            spike_train.append(model.exc.s.clone().detach().cpu())

        spike_train = torch.stack(spike_train)

        total_rates[label] += torch.sum(spike_train, dim=0) / opt.time_interval
        total_labels[label] += 1

        if (step+1) % 1000 == 0:
            end = time.time()
            print("elapsed times: {} - number of images: {}".format(end-start, step+1))

    # Average firing rate per class, then assign each excitatory neuron to the
    # class for which it fired most on average
    total_avg_rates = total_rates / total_labels.unsqueeze(dim=1)

    assigned_label = torch.argmax(total_avg_rates, dim=0)

    print(assigned_label)

    torch.save({'assigned_label': assigned_label}, opt.assigned)
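
The saved assigned_label tensor maps each excitatory neuron to the class with the highest average firing rate. Below is a minimal sketch of how that assignment could drive inference for a single image, reusing the model and encoder objects built above; classify and its arguments are hypothetical names, not part of n3ml.

import torch

def classify(model, encoder, image, assigned_label, time_interval, num_classes):
    # Hypothetical inference step: run the encoded image through the trained
    # network, count each excitatory neuron's spikes, and vote by the class
    # assigned to the most active neurons.
    model.init_param()

    spiked_image = encoder(image.view(1, 28, 28))
    spiked_image = spiked_image.view(time_interval, -1).cuda()

    spike_counts = torch.zeros_like(assigned_label, dtype=torch.float)
    for t in range(time_interval):
        model.run({'inp': spiked_image[t]})
        spike_counts += model.exc.s.clone().detach().cpu()

    votes = torch.zeros(num_classes)
    for c in range(num_classes):
        votes[c] = spike_counts[assigned_label == c].sum()
    return torch.argmax(votes).item()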
Example #3
def train(loader, model, encoder, optimizer, criterion, opt) -> None:
    plotter = Plot()

    num_images = 0
    total_loss = 0.0
    num_corrects = 0

    list_loss = []
    list_acc = []

    for image, label in loader:
        # Squeeze the batch dimension; batch processing isn't supported yet
        image = image.squeeze(dim=0)
        label = label.squeeze()

        spiked_image = encoder(image)
        spiked_image = spiked_image.view(spiked_image.size(0), -1)

        spiked_label = label_encoder(label, opt.beta, opt.num_classes, opt.time_interval)

        # Recent spike history of each layer; consumed by optimizer.step() below
        spike_buffer = {
            'inp': [],
            'fc1': [],
            'fc2': []
        }

        loss_buffer = []

        print()
        print("label: {}".format(label))

        # Simulate the network over the time interval, updating the spike
        # buffers and applying the optimizer at every timestep
        for t in range(opt.time_interval):
            model(spiked_image[t])

            spike_buffer['inp'].append(spiked_image[t].clone())
            spike_buffer['fc1'].append(model.fc1.o.clone())
            spike_buffer['fc2'].append(model.fc2.o.clone())

            loss_buffer.append(model.fc2.o.clone())

            # Keep only the most recent spikes in each buffer
            for l in spike_buffer.values():
                if len(l) > 5:  # TODO: express the hard-coded 5 in terms of epsilon
                    l.pop(0)

            print(model.fc2.o.numpy())

            optimizer.step(spike_buffer, spiked_label[t], label)

        model.reset_variables(w=False)

        num_images += 1
        num_corrects += accuracy(r=torch.stack(loss_buffer), label=label)
        total_loss += criterion(r=torch.stack(loss_buffer), z=spiked_label, label=label, epsilon=opt.epsilon)

        if num_images > 0 and num_images % 30 == 0:
            list_loss.append(total_loss / num_images)
            list_acc.append(float(num_corrects) / num_images)

            plotter.update(y1=np.array(list_acc), y2=np.array(list_loss))
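
The accuracy helper called in Examples #1 and #3 is also not included in this listing. Below is a minimal sketch consistent with how it is called, where r is the stacked spike train of the output layer over the time interval; the real helper may be implemented differently.

import torch

def accuracy(r, label):
    # Hypothetical sketch: sum the output spikes over time and count a hit
    # when the most active output unit matches the target label.
    return int(torch.argmax(torch.sum(r, dim=0)) == label)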