Esempio n. 1
0
 def get_test(self, classes, encoder, batch_size):
     """Return a non-shuffling loader over the MNIST test split.

     The wrapped dataset is restricted with ``subset_idx=range(self.n_eval)``
     and receives this object's ``label_shape``, ``assignments`` and
     transform ``T``; labels are built with ``label_intensity=0.``.

     Args:
         classes: forwarded to ``MNISTWrapper`` as ``classes``.
         encoder: forwarded to ``MNISTWrapper`` as its first argument.
         batch_size: number of samples per batch.

     Returns:
         A ``DataLoader`` configured from ``P.N_WORKERS`` / ``P.GPU``.
     """
     eval_indices = range(self.n_eval)
     test_set = MNISTWrapper(
         encoder,
         subset_idx=eval_indices,
         classes=classes,
         label_shape=self.label_shape,
         label_intensity=0.,
         assignments=self.assignments,
         root=P.DATA_FOLDER,
         train=False,
         download=True,
         transform=self.T,
     )
     loader = DataLoader(
         test_set,
         batch_size=batch_size,
         shuffle=False,
         num_workers=P.N_WORKERS,
         pin_memory=P.GPU,
     )
     return loader
Esempio n. 2
0
 def get_val(self, classes, encoder, batch_size):
     """Return a non-shuffling validation loader.

     When ``self.validate_on_tst_set`` is truthy this simply delegates to
     ``get_test``. Otherwise it wraps the last ``self.n_eval`` indices of the
     training split, i.e. ``range(P.TRN_SET_SIZE - self.n_eval,
     P.TRN_SET_SIZE)``.

     Args:
         classes: forwarded to ``MNISTWrapper`` as ``classes``.
         encoder: forwarded to ``MNISTWrapper`` as its first argument.
         batch_size: number of samples per batch.

     Returns:
         A ``DataLoader`` configured from ``P.N_WORKERS`` / ``P.GPU``.
     """
     if self.validate_on_tst_set:
         # Validation reuses the test-split loader.
         return self.get_test(classes, encoder, batch_size)

     trn_end = P.TRN_SET_SIZE
     held_out = range(trn_end - self.n_eval, trn_end)
     val_set = MNISTWrapper(
         encoder,
         subset_idx=held_out,
         classes=classes,
         label_shape=self.label_shape,
         label_intensity=0.,
         assignments=self.assignments,
         root=P.DATA_FOLDER,
         train=True,
         download=True,
         transform=self.T,
     )
     loader = DataLoader(
         val_set,
         batch_size=batch_size,
         shuffle=False,
         num_workers=P.N_WORKERS,
         pin_memory=P.GPU,
     )
     return loader
Esempio n. 3
0
    def train(self) -> None:
        # language=rst
        """
        Training loop that runs for the set number of epochs and creates a new
        ``DataLoader`` at each epoch.

        Every batch yielded by the loader is passed to ``self.step``. The
        progress-bar total is ``len(train_dataloader)``, so a final partial
        batch (when ``len(self.train_ds)`` is not a multiple of
        ``self.batch_size``) is counted.
        """
        for epoch in range(self.num_epochs):
            train_dataloader = DataLoader(
                self.train_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                shuffle=self.shuffle,
            )

            progress = tqdm(
                train_dataloader,
                desc="Epoch %d/%d" % (epoch + 1, self.num_epochs),
                # len(DataLoader) rounds up, unlike floor division.
                total=len(train_dataloader),
            )
            for batch in progress:
                self.step(batch)
Esempio n. 4
0
def main(args):
    """Train a ``DiehlAndCook2015v2`` network on MNIST, logging to TensorBoard.

    Every ``args.update_steps`` batches, the spikes recorded since the last
    update are scored with the all-activity and proportion-weighting
    classifiers. The input->excitatory weights are logged as an image, and
    the neuron label assignments are recomputed from those spikes.

    Args:
        args: parsed command-line namespace. Reads ``update_steps``,
            ``batch_size``, ``gpu``, ``seed``, ``n_workers``, ``n_neurons``,
            ``reduction``, ``inh``, ``dt``, ``theta_plus``, ``time``,
            ``intensity``, ``log_dir``, ``n_epochs``, ``one_step`` and
            ``plot``. ``update_steps`` (when ``None``) and ``n_workers``
            (when ``-1``) are filled in on ``args`` itself.

    Raises:
        NotImplementedError: if ``args.reduction`` is not ``"sum"``,
            ``"mean"`` or ``"max"``.

    Note:
        ``args.log_dir`` is deleted before logging starts. ``accuracy`` (a
        list) and ``end_accuracy()`` are not defined in this function;
        presumably they are module-level names. Confirm before reuse.
    """
    if args.update_steps is None:
        # Default: evaluate after roughly 250 samples, but at least every batch.
        args.update_steps = max(250 // args.batch_size, 1)

    # Number of samples between two evaluations / label re-assignments.
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use and seeding. torch.manual_seed seeds every device and is
    # always called. DataLoader shuffling (and, presumably, the dataset-side
    # Poisson encoding) draw from the CPU generator even when the network runs
    # on the GPU, so seeding CUDA alone is not reproducible.
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # -1 selects 4 workers per visible GPU (0 when args.gpu is falsy).
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Side of the square grid used to tile per-neuron weight images.
    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # Reduction handed to the network (applied to the batched weight updates).
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network: 28x28 MNIST images give 784 inputs.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Shared preprocessing: convert to a tensor, then scale by args.intensity.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * args.intensity),
    ])
    mnist_root = os.path.join(ROOT_DIR, "data", "MNIST")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=mnist_root,
        download=True,
        train=True,
        transform=transform,
    )

    # NOTE(review): the test split is loaded (and downloaded) but not used below.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=mnist_root,
        download=True,
        train=False,
        transform=transform,
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    # -1 marks neurons that have not been assigned a label yet.
    assignments = -torch.ones(args.n_neurons)
    # (n_neurons, n_classes) tables, filled in by assign_labels.
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Record the spikes ("s") of every layer over the simulation window.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(
            network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Excitatory ("Y") spikes of the last update_interval samples.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start every run from an empty log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer; pending events are flushed to disk every 60 seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate over shuffled batches.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            # Global sample index of this batch (len(dataset) is 60000 for
            # the MNIST training split).
            global_step = len(dataset) * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:
                # Convert the list of labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions from the recorded spikes.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                all_vote_accuracy = torch.mean(
                    (label_tensor.long() == all_activity_pred).float())
                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=all_vote_accuracy,
                    global_step=global_step,
                )
                # Keep the all-vote accuracy of every interval (consumed by
                # end_accuracy()).
                value = all_vote_accuracy.item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Store this batch's excitatory spikes at its slot in the record,
            # reordered so the batch dimension comes first.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            start_idx = (step * args.batch_size) % update_interval
            spike_record[start_idx:start_idx + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
        print(end_accuracy())
Esempio n. 5
0
# Train the network.
# Script fragment: relies on names defined outside this view (t, n_epochs,
# progress_interval, train_dataset, batch_size, gpu, network, time, tqdm,
# DataLoader).
print("Begin training.\n")
start = t()

weights_im = None

for epoch in range(n_epochs):
    # Every progress_interval epochs, report the elapsed wall time since the
    # last report, then restart the timer.
    if epoch % progress_interval == 0:
        print("Progress: %d / %d (%.4f seconds)" %
              (epoch, n_epochs, t() - start))
        start = t()

    # A new loader (shuffle=True) is built for every epoch.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=gpu,
    )

    for step, batch in enumerate(tqdm(train_dataloader)):
        # Get next input sample.

        inpts = {"X": batch["encoded_image"]}
        if gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}
        # NOTE(review): label, step and weights_im are unused in the visible
        # lines; this fragment appears to be truncated.
        label = batch["label"]

        # Run the network on the input.
        # Here time is a separately defined value (presumably the simulation
        # length); the wall-clock timer in this fragment is t. input_time_dim=1
        # presumably tells network.run which input dimension holds time.
        # TODO: confirm against the encoder's output layout.
        network.run(inpts=inpts, time=time, input_time_dim=1)
Esempio n. 6
0
def main(args):
    """Train a convolutional spiking layer on MNIST with the ``PostPre`` rule.

    A 28x28 input layer ("X") feeds a ``Conv2dConnection`` into a layer of
    ``DiehlAndCookNodes`` ("Y"). A recurrent Y->Y connection with weight
    -100 links neurons of *different* filters at the same spatial position.
    The learning rate of the X->Y connection is multiplied by 0.99 after
    every batch.

    Args:
        args: parsed command-line namespace. Reads ``gpu``, ``seed``,
            ``kernel_size``, ``padding``, ``stride``, ``n_filters``, ``lr``,
            ``time``, ``dt``, ``intensity``, ``n_epochs``,
            ``progress_interval``, ``batch_size`` and ``plot``.
    """
    # torch.manual_seed seeds every device and is always called. DataLoader
    # shuffling and the dataset-side encoding draw from the CPU generator
    # even on GPU runs.
    torch.manual_seed(args.seed)
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)

    # Output side length of the convolution:
    # floor((in - kernel + 2 * padding) / stride) + 1.
    conv_size = int(
        (28 - args.kernel_size + 2 * args.padding) / args.stride) + 1

    # Build network.
    network = Network()
    input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    conv_layer = DiehlAndCookNodes(
        n=args.n_filters * conv_size * conv_size,
        shape=(args.n_filters, conv_size, conv_size),
        traces=True,
    )

    conv_conn = Conv2dConnection(
        input_layer,
        conv_layer,
        kernel_size=args.kernel_size,
        stride=args.stride,
        # Must match the padding used to compute conv_size above, otherwise
        # the connection's output shape disagrees with conv_layer.
        padding=args.padding,
        update_rule=PostPre,
        norm=0.4 * args.kernel_size**2,
        nu=[0, args.lr],
        reduction=max_without_indices,
        wmax=1.0,
    )

    # Lateral inhibition: w[f1, i, j, f2, i, j] = -100 for f1 != f2; every
    # other entry is 0.
    n_filters = args.n_filters
    w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size,
                    conv_size)
    inhibition = torch.full((n_filters, n_filters), -100.0)
    inhibition.fill_diagonal_(0.0)  # no self-inhibition within a filter
    for i in range(conv_size):
        for j in range(conv_size):
            w[:, i, j, :, i, j] = inhibition

    w = w.view(n_filters * conv_size * conv_size,
               n_filters * conv_size * conv_size)
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name="X")
    network.add_layer(conv_layer, name="Y")
    network.add_connection(conv_conn, source="X", target="Y")
    network.add_connection(recurrent_conn, source="Y", target="Y")

    # Voltage recording for the convolutional ("Y") layer.
    voltage_monitor = Monitor(network.layers["Y"], ["v"], time=args.time)
    network.add_monitor(voltage_monitor, name="output_voltage")

    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    train_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Spike monitors for every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    # Voltage monitors for every non-input layer.
    voltages = {}
    for layer in set(network.layers) - {"X"}:
        voltages[layer] = Monitor(network.layers[layer],
                                  state_vars=["v"],
                                  time=args.time)
        network.add_monitor(voltages[layer], name="%s_voltages" % layer)

    # Train the network.
    print("Begin training.\n")
    start = time()

    weights_im = None

    for epoch in range(args.n_epochs):
        if epoch % args.progress_interval == 0:
            print("Progress: %d / %d (%.4f seconds)" %
                  (epoch, args.n_epochs, time() - start))
            start = time()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(tqdm(train_dataloader)):
            # Get next input sample.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time, input_time_dim=0)

            # Decay learning rate (applied after every batch).
            network.connections["X", "Y"].nu[1] *= 0.99

            # Optionally plot various simulation information.
            if args.plot:
                weights = conv_conn.w
                weights_im = plot_conv2d_weights(weights, im=weights_im)

                plt.pause(1e-8)

            network.reset_()  # Reset state variables.

    print("Progress: %d / %d (%.4f seconds)\n" %
          (args.n_epochs, args.n_epochs, time() - start))
    print("Training complete.\n")
Esempio n. 7
0
def main(args):
    """Train ``DiehlAndCook2015v2`` on MNIST with a held-out validation split.

    The MNIST training split is randomly divided into 59,000 training and
    1,000 validation samples. Every ``args.update_steps`` training batches,
    the following happens in order:

    - Learning is switched off.
    - The validation set is run and scored with the current label assignments.
    - The accuracy on the training spikes recorded since the last update is
      logged.
    - The label assignments are recomputed from those training spikes.
    - Learning is switched back on.

    Args:
        args: parsed command-line namespace. Reads ``update_steps``,
            ``batch_size``, ``test_batch_size``, ``gpu``, ``seed``,
            ``n_workers``, ``n_neurons``, ``reduction``, ``inh``, ``dt``,
            ``theta_plus``, ``time``, ``intensity``, ``log_dir``,
            ``n_epochs``, ``one_step``, ``plot`` and ``valid_plot``.
            ``update_steps`` (when ``None``) and ``n_workers`` (when ``-1``)
            are filled in on ``args`` itself.

    Raises:
        NotImplementedError: if ``args.reduction`` is not ``"sum"``,
            ``"mean"`` or ``"max"``.

    Note:
        ``args.log_dir`` is deleted before logging starts.
    """
    if args.update_steps is None:
        # Default: evaluate after roughly 250 samples, but at least every batch.
        args.update_steps = max(250 // args.batch_size, 1)

    # Number of training samples between two evaluations.
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use and seeding. torch.manual_seed seeds every device and is
    # always called. random_split, DataLoader shuffling and the dataset-side
    # encoding draw from the CPU generator even on GPU runs.
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(args.seed)
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # -1 selects 4 workers per visible GPU (0 when args.gpu is falsy).
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Side of the square grid used to tile per-neuron weight images.
    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Shared preprocessing: convert to a tensor, then scale by args.intensity.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * args.intensity),
    ])
    mnist_root = os.path.join(ROOT_DIR, "data", "MNIST")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=mnist_root,
        download=True,
        train=True,
        transform=transform,
    )

    # Hold out 1,000 training images for validation.
    dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [59000, 1000])

    # NOTE(review): the test split is loaded (and downloaded) but not used below.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=mnist_root,
        download=True,
        train=False,
        transform=transform,
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    # -1 marks neurons that have not been assigned a label yet.
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up spike monitors for every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            # Global sample index of this batch. The training subset holds
            # len(dataset) == 59000 samples after the split, not 60000.
            global_step = len(dataset) * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Disable learning.
                network.train(False)

                # Get validation data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time,
                                                args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts,
                                time=args.time,
                                one_step=args.one_step)

                    # Add to spikes recording (batch dimension first).
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    start_idx = test_step * args.test_batch_size
                    test_spike_record[start_idx:start_idx + s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons),
                            n_sqrt, 28)
                        spikes_ = {
                            layer: spikes[layer].get("s")[:, 0]
                            for layer in spikes
                        }
                        spike_ims, spike_axes = plot_spikes(spikes_,
                                                            ims=spike_ims,
                                                            axes=spike_axes)
                        weights_im = plot_weights(square_weights,
                                                  im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test",
                                  scalar_value=t1,
                                  global_step=global_step)

                # Convert the list of labels into a tensor.
                test_label_tensor = torch.tensor(test_labels)

                # Get network predictions on the validation spikes.
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long()
                         == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Convert the list of training labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions on the recorded training spikes.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step",
                              scalar_value=t1,
                              global_step=global_step)

            # Add to spikes recording at this batch's slot in the interval.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            start_idx = (step * args.batch_size) % update_interval
            spike_record[start_idx:start_idx + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
Esempio n. 8
0
def _plot_state(network, spike_ims, spike_axes, weights_im):
    """Redraw monitored spikes and the ``I -> O`` weights; return new handles.

    For each monitor, plots index 0 of the second dimension of its ``"s"``
    record (presumably the first sample of the batch). The weights are tiled
    on a 4x4 grid of 28x28 images; this assumes a 784-input, at most
    16-output connection, as built in ``main``.

    Returns:
        ``(spike_ims, spike_axes, weights_im)`` to pass into the next call.
    """
    spikes = {
        l: network.monitors[l].get("s")[:, 0]
        for l in network.monitors
    }
    spike_ims, spike_axes = plot_spikes(spikes=spikes,
                                        ims=spike_ims,
                                        axes=spike_axes)

    connection = network.connections["I", "O"]
    weights = get_square_weights(connection.w, n_sqrt=4, side=28)
    weights_im = plot_weights(weights, wmax=connection.wmax, im=weights_im)

    plt.pause(1e-2)
    return spike_ims, spike_axes, weights_im


def main(args):
    """Hebbian-train a 10-neuron LIF output layer on MNIST with label clamping.

    The loader is iterated twice. In the first pass, learning is on and the
    output neuron of the true label is passed as ``clamp`` while the
    remaining nine are passed as ``unclamp``. In the second pass, learning is
    disabled and no clamping is applied.

    NOTE(review): both passes use the MNIST *test* split (``train=False``),
    and the second pass computes no accuracy. Confirm this is intended.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``gpu``,
            ``n_workers`` (``-1`` is replaced in place), ``batch_size``,
            ``time``, ``one_step`` and ``plot``.
    """
    # Random seeding (torch.manual_seed seeds every device).
    torch.manual_seed(args.seed)

    # Device.
    device = torch.device("cuda" if args.gpu else "cpu")

    # -1 selects 4 workers per visible GPU (0 when args.gpu is falsy).
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Build network.
    network = Network(batch_size=args.batch_size)
    network.add_layer(Input(shape=(1, 28, 28), traces=True), name="I")
    network.add_layer(LIFNodes(n=10,
                               traces=True,
                               rest=0,
                               reset=0,
                               thresh=1,
                               refrac=0),
                      name="O")
    network.add_connection(
        Connection(
            source=network.layers["I"],
            target=network.layers["O"],
            nu=(0.0, 0.01),
            update_rule=Hebbian,
            wmin=0.0,
            wmax=1.0,
            norm=100.0,
            reduction=torch.sum,
        ),
        source="I",
        target="O",
    )

    if args.plot:
        for l in network.layers:
            network.add_monitor(Monitor(network.layers[l],
                                        state_vars=("s", ),
                                        time=args.time),
                                name=l)

    network.to(device)

    # Load dataset.
    dataset = MNIST(
        image_encoder=PoissonEncoder(time=args.time, dt=1.0),
        label_encoder=None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Lambda(lambda x: x * 250)]),
    )

    # Create a dataloader to iterate and batch data.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=args.gpu,
    )

    spike_ims = None
    spike_axes = None
    weights_im = None

    t0 = time()

    # Pass 1: supervised learning with the output layer clamped to the labels.
    for step, batch in enumerate(tqdm(dataloader)):
        # Prep next input batch.
        inpts = {"I": batch["encoded_image"]}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        # Boolean masks: ``~`` must be a logical NOT. On a uint8 (.byte())
        # tensor it is a bitwise NOT (1 -> 254, 0 -> 255), which makes every
        # entry of the "unclamp" mask nonzero.
        clamp = torch.nn.functional.one_hot(batch["label"],
                                            num_classes=10).bool()
        unclamp = ~clamp

        # Run the network on the input.
        network.run(
            inpts=inpts,
            time=args.time,
            one_step=args.one_step,
            clamp={"O": clamp},
            unclamp={"O": unclamp},
        )

        if args.plot:
            spike_ims, spike_axes, weights_im = _plot_state(
                network, spike_ims, spike_axes, weights_im)

        # Reset state variables.
        network.reset_()

    # Pass 2: learning disabled, no clamping.
    network.train(False)

    for step, batch in enumerate(tqdm(dataloader)):
        # Prep next input batch.
        inpts = {"I": batch["encoded_image"]}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        # Run the network on the input.
        network.run(inpts=inpts, time=args.time, one_step=args.one_step)

        if args.plot:
            spike_ims, spike_axes, weights_im = _plot_state(
                network, spike_ims, spike_axes, weights_im)

        # Reset state variables so no activity carries over between batches.
        network.reset_()

    t1 = time() - t0

    print(f"Time: {t1}")
Esempio n. 9
0
def main(args):
    """Convert a pre-trained MLP to a spiking network and evaluate it on MNIST.

    Loads ``ann.pt`` from ``args.job_dir``, converts it with ``ann_to_snn``
    (using the MNIST test images, scaled by their maximum value, as
    normalization data), runs the resulting SNN over the MNIST test split and
    appends one CSV row (seed, simulation time, batch size, inference time,
    accuracy) to ``ROOT_DIR/results/<args.results_file>``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``seed``,
            ``gpu``, ``n_workers``, ``job_dir``, ``time``, ``batch_size``,
            ``one_step``, ``results_file``.

    Returns:
        Wall-clock inference time in seconds (measured with ``time()``).
    """
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # -1 means "choose automatically": 4 workers per visible GPU, 0 on CPU.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    device = torch.device("cuda" if args.gpu else "cpu")

    # Load trained MLP from disk. ``map_location`` lets a checkpoint that was
    # saved from GPU tensors be restored on a CPU-only host.
    ann = MLP().to(device)
    f = os.path.join(args.job_dir, "ann.pt")
    ann.load_state_dict(state_dict=torch.load(f=f, map_location=device))

    # Load dataset (test split); each image is normalized, flattened to a
    # vector and repeated for ``args.time`` steps by the encoder.
    dataset = MNIST(
        image_encoder=RepeatEncoder(time=args.time, dt=1.0),
        label_encoder=None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
            transforms.Lambda(lambda x: x.view(-1)),
        ]),
    )

    # Do ANN to SNN conversion, using the raw pixel data scaled by its
    # maximum and flattened to (N, 784) as normalization data.
    data = dataset.data.float()
    data /= data.max()
    data = data.view(-1, 784)
    snn = ann_to_snn(ann, input_shape=(784, ), data=data.to(device))
    snn = snn.to(device)

    print(snn)

    # Create a dataloader to iterate and batch data.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=args.gpu,
    )

    correct = 0
    t0 = time()
    for batch in tqdm(dataloader):
        # Prep next input batch.
        inputs = batch["encoded_image"]
        labels = batch["label"]

        inpts = {"Input": inputs}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        # Run the network on the input.
        snn.run(inpts=inpts, time=args.time, one_step=args.one_step)

        # Classify by the layer named "5"'s ``summed`` attribute, taking the
        # argmax over dim 1 (assumed to be the class dimension — TODO confirm).
        output_voltages = snn.layers["5"].summed
        prediction = torch.softmax(output_voltages, dim=1).argmax(dim=1)
        correct += (prediction.cpu() == labels).sum().item()

        # Reset state variables before the next batch.
        snn.reset_()

    t1 = time() - t0

    accuracy = 100 * correct / len(dataloader.dataset)

    print(f"SNN accuracy: {accuracy:.2f}")

    # Append results as CSV, writing the header only when creating the file.
    path = os.path.join(ROOT_DIR, "results", args.results_file)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if not os.path.isfile(path):
        with open(path, "w") as f:
            f.write(
                "seed,simulation time,batch size,inference time,accuracy\n")

    to_write = [args.seed, args.time, args.batch_size, t1, accuracy]
    to_write = ",".join(map(str, to_write)) + "\n"
    with open(path, "a") as f:
        f.write(to_write)

    return t1
def main(args):
    """Train a convolutional spiking network on MNIST with ``WeightDependentPost``.

    Builds a BindsNET network I -> C1 (5x5 conv) -> P1 (2x2 pool) -> C2 (5x5
    conv) -> P2 (2x2 pool) -> D (dense, 200) -> O (output, 10), feeds it
    Poisson-encoded MNIST images scaled by ``args.intensity`` for
    ``args.n_epochs`` epochs, and optionally plots spikes and weights after
    every batch.

    Args:
        args: Parsed command-line namespace. Fields read here: ``seed``,
            ``gpu``, ``n_workers``, ``dt``, ``batch_size``, ``nu``, ``time``,
            ``intensity``, ``n_epochs``, ``plot``.
    """
    # Random seed. The CPU generator must always be seeded: it drives the
    # DataLoader shuffle and the per-worker seeds even when the network runs
    # on the GPU. Seeding only CUDA (as before) left GPU runs unreproducible.
    torch.manual_seed(args.seed)
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # -1 means "choose automatically": 4 workers per visible GPU, 0 on CPU.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Build network.
    network = bindsnet.network.Network(dt=args.dt, batch_size=args.batch_size)

    # Layers. Spatial sizes follow from 5x5/stride-1 convs and 2x2/stride-2
    # pooling: 28 -> 24 -> 12 -> 8 -> 4.
    input_layer = Input(shape=(1, 28, 28), traces=True)
    conv1_layer = LIFNodes(shape=(20, 24, 24), traces=True)
    pool1_layer = PassThroughNodes(shape=(20, 12, 12), traces=True)
    conv2_layer = LIFNodes(shape=(50, 8, 8), traces=True)
    pool2_layer = PassThroughNodes(shape=(50, 4, 4), traces=True)
    dense_layer = LIFNodes(shape=(200, ), traces=True)
    output_layer = LIFNodes(shape=(10, ), traces=True)

    network.add_layer(input_layer, name="I")
    network.add_layer(conv1_layer, name="C1")
    network.add_layer(pool1_layer, name="P1")
    network.add_layer(conv2_layer, name="C2")
    network.add_layer(pool2_layer, name="P2")
    network.add_layer(dense_layer, name="D")
    network.add_layer(output_layer, name="O")

    # Connections. Learnable ones use WeightDependentPost with nu=(0, args.nu)
    # and weights bounded to [-1, 1]; pooling connections are not learned.
    conv1_connection = Conv2dConnection(
        source=input_layer,
        target=conv1_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool1_connection = SpatialPooling2dConnection(source=conv1_layer,
                                                  target=pool1_layer,
                                                  kernel_size=2,
                                                  stride=2)
    conv2_connection = Conv2dConnection(
        source=pool1_layer,
        target=conv2_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        kernel_size=5,
        stride=1,
        wmin=-1.0,
        wmax=1.0,
    )
    pool2_connection = SpatialPooling2dConnection(source=conv2_layer,
                                                  target=pool2_layer,
                                                  kernel_size=2,
                                                  stride=2)
    dense_connection = Connection(
        source=pool2_layer,
        target=dense_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )
    output_connection = Connection(
        source=dense_layer,
        target=output_layer,
        update_rule=WeightDependentPost,
        nu=(0.0, args.nu),
        wmin=-1.0,
        wmax=1.0,
    )

    network.add_connection(connection=conv1_connection,
                           source="I",
                           target="C1")
    network.add_connection(connection=pool1_connection,
                           source="C1",
                           target="P1")
    network.add_connection(connection=conv2_connection,
                           source="P1",
                           target="C2")
    network.add_connection(connection=pool2_connection,
                           source="C2",
                           target="P2")
    network.add_connection(connection=dense_connection,
                           source="P2",
                           target="D")
    network.add_connection(connection=output_connection,
                           source="D",
                           target="O")

    # Monitors: record spikes ("s") of every layer for plotting.
    for name, layer in network.layers.items():
        monitor = Monitor(obj=layer, state_vars=("s", ), time=args.time)
        network.add_monitor(monitor=monitor, name=name)

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data (split chosen by the MNIST wrapper's default).
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(bindsnet.ROOT_DIR, "data", "MNIST"),
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Plot handles, reused across batches so figures update in place.
    spike_ims = None
    spike_axes = None
    conv1_weights_im = None
    conv2_weights_im = None
    dense_weights_im = None
    output_weights_im = None

    # The loader is loop-invariant; with shuffle=True it draws a new
    # permutation each time it is iterated, i.e. once per epoch.
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=args.gpu,
    )

    for epoch in range(args.n_epochs):
        for batch in tqdm(dataloader):
            # Prep next input batch.
            inpts = {"I": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time)

            # Plot simulation data for the first sample of the batch.
            if args.plot:
                spikes = {}
                for name, monitor in network.monitors.items():
                    spikes[name] = monitor.get("s")[:, 0].view(args.time, -1)

                spike_ims, spike_axes = plot_spikes(spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)

                conv1_weights_im = plot_conv2d_weights(conv1_connection.w,
                                                       im=conv1_weights_im,
                                                       wmin=-1.0,
                                                       wmax=1.0)
                conv2_weights_im = plot_conv2d_weights(conv2_connection.w,
                                                       im=conv2_weights_im,
                                                       wmin=-1.0,
                                                       wmax=1.0)
                dense_weights_im = plot_weights(dense_connection.w,
                                                im=dense_weights_im,
                                                wmin=-1.0,
                                                wmax=1.0)
                output_weights_im = plot_weights(output_connection.w,
                                                 im=output_weights_im,
                                                 wmin=-1.0,
                                                 wmax=1.0)

                plt.pause(1e-8)

            # Reset state variables before the next batch.
            network.reset_()
def main(args):
    """Convert a pre-trained VGG ANN to a spiking network and evaluate it.

    Builds the ANN selected by ``args.arch`` (only ``'vgg15ab'`` is wired up),
    loads weights from ``args.model``, converts it with ``ann_to_snn`` using one
    batch of ``args.norm`` training images as normalization data, then runs the
    SNN on up to ``args.eval_size`` validation samples of ``args.dataset``
    (``'imagenet'`` or ``'cifar100'``). Only whole batches are simulated.

    Side effects: saves a per-timestep accuracy plot (``.png``) and array
    (``.npy``) under ``args.job_dir`` and appends one CSV row (seed, simulation
    time, batch size, inference time, accuracy) to
    ``<job_dir>/results/<results_file>``.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Wall-clock inference time in seconds (measured with ``time()``).

    Raises:
        ValueError: For an unknown architecture or dataset, or when
            ``eval_size`` is smaller than ``batch_size`` (no whole batch
            would be evaluated).
    """
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    device = torch.device("cuda" if args.gpu else "cpu")

    # Load trained ANN from disk.
    if args.arch == 'vgg15ab':
        ann = vgg_15_avg_before_relu(dataset=args.dataset)
    # add other architectures here
    else:
        raise ValueError('Unknown architecture')

    # ``features`` is wrapped in DataParallel before loading, presumably so the
    # parameter names match how the checkpoint was saved (strict=True below).
    ann.features = torch.nn.DataParallel(ann.features)
    # Previously ``ann.cuda()`` ran unconditionally and crashed on CPU-only
    # hosts; without CUDA, DataParallel simply runs the wrapped module.
    if torch.cuda.is_available():
        ann.cuda()
    os.makedirs(args.job_dir, exist_ok=True)
    f = os.path.join('.', args.model)
    # Read the checkpoint once (on CPU, so GPU-saved weights load anywhere);
    # accept either {'state_dict': ...} or a bare state dict.
    checkpoint = torch.load(f=f, map_location="cpu")
    try:
        dictionary = checkpoint['state_dict']
    except KeyError:
        dictionary = checkpoint
    ann.load_state_dict(state_dict=dictionary, strict=True)

    if args.dataset == 'imagenet':
        input_shape = (3, 224, 224)

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        # The actual data to be evaluated.
        val_loader = ImageNet(image_encoder=RepeatEncoder(time=args.time,
                                                          dt=1.0),
                              label_encoder=None,
                              root=args.data,
                              download=False,
                              transform=transforms.Compose([
                                  transforms.Resize((256, 256)),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor(),
                                  normalize,
                              ]),
                              split='val')
        dataloader = DataLoader(
            val_loader,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )
        # Training-set samples used to normalize the SNN during conversion.
        norm_loader = ImageNet(
            image_encoder=RepeatEncoder(time=args.time, dt=1.0),
            label_encoder=None,
            root=args.data,
            download=False,
            split='train',
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]),
        )

    elif args.dataset == 'cifar100':
        input_shape = (3, 32, 32)
        print('==> Using Pytorch CIFAR-100 Dataset')
        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441],
                                         std=[0.267, 0.256, 0.276])
        # Evaluation data must be deterministic: no random crop/flip here
        # (the previous augmentation made the reported accuracy noisy).
        val_loader = CIFAR100(image_encoder=RepeatEncoder(time=args.time,
                                                          dt=1.0),
                              label_encoder=None,
                              root=args.data,
                              download=True,
                              train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  normalize,
                              ]))

        dataloader = DataLoader(
            val_loader,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=args.gpu,
        )

        # Training-set samples used to normalize the SNN during conversion.
        norm_loader = CIFAR100(image_encoder=RepeatEncoder(time=args.time,
                                                           dt=1.0),
                               label_encoder=None,
                               root=args.data,
                               download=True,
                               train=True,
                               transform=transforms.Compose([
                                   transforms.RandomCrop(32, padding=4),
                                   transforms.RandomHorizontalFlip(0.5),
                                   transforms.ToTensor(),
                                   normalize,
                               ]))
    else:
        raise ValueError('Unsupported dataset.')

    if args.eval_size == -1:
        args.eval_size = len(val_loader)
    if args.eval_size < args.batch_size:
        raise ValueError(
            f"eval_size ({args.eval_size}) must be >= batch_size "
            f"({args.batch_size}); only whole batches are evaluated.")

    # A single batch of ``args.norm`` training images drives normalization.
    data = next(iter(torch.utils.data.DataLoader(
        norm_loader, batch_size=args.norm)))['image']

    snn = ann_to_snn(ann,
                     input_shape=input_shape,
                     data=data,
                     percentile=args.percentile)

    torch.cuda.empty_cache()
    snn = snn.to(device)

    correct = 0
    n_evaluated = 0
    t0 = time()
    # One column per evaluated batch; ``snn.run`` fills it (assumed: number of
    # correct predictions per timestep — TODO confirm against snn.run).
    accuracies = np.zeros((args.time, (args.eval_size // args.batch_size) + 1),
                          dtype=np.float32)
    for step, batch in enumerate(tqdm(dataloader)):
        # Stop before a batch would exceed eval_size.
        if (step + 1) * args.batch_size > args.eval_size:
            break
        # Prep next input batch.
        inputs = batch["encoded_image"]
        labels = batch["label"]
        inpts = {"Input": inputs}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        snn.run(inpts=inpts,
                time=args.time,
                step=step,
                acc=accuracies,
                labels=labels,
                one_step=args.one_step)
        last_layer = list(snn.layers.keys())[-1]
        output_voltages = snn.layers[last_layer].summed
        prediction = torch.softmax(output_voltages, dim=1).argmax(dim=1)
        correct += (prediction.cpu() == labels).sum().item()
        n_evaluated += len(labels)
        snn.reset_()
    t1 = time() - t0

    # Normalize by the samples actually simulated, not the requested
    # eval_size: the trailing partial batch is skipped by the loop above.
    final = accuracies.sum(axis=1) / n_evaluated

    plt.plot(final)
    plt.suptitle('{} {} ANN-SNN@{} percentile'.format(args.dataset, args.arch,
                                                      args.percentile),
                 fontsize=20)
    plt.xlabel('Timestep', fontsize=19)
    plt.ylabel('Accuracy', fontsize=19)
    plt.grid()
    # Save before show(): show() may close the figure, leaving savefig blank.
    plt.savefig('{}/{}_{}.png'.format(args.job_dir, args.arch,
                                      args.percentile))
    plt.show()
    np.save(
        '{}/voltage_accuracy_{}_{}.npy'.format(args.job_dir, args.arch,
                                               args.percentile), final)

    accuracy = 100 * correct / n_evaluated

    print(f"SNN accuracy: {accuracy:.2f}")
    # time() differences are in seconds.
    print(f"Clock time used: {t1:.4f} s.")
    path = os.path.join(args.job_dir, "results", args.results_file)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if not os.path.isfile(path):
        with open(path, "w") as f:
            f.write(
                "seed,simulation time,batch size,inference time,accuracy\n")
    to_write = [args.seed, args.time, args.batch_size, t1, accuracy]
    to_write = ",".join(map(str, to_write)) + "\n"
    with open(path, "a") as f:
        f.write(to_write)

    return t1