Example #1
def test(network, data, labels):
    activities = torch.zeros(len(data), RUN_TIME, len(SUBJECTS)) # data_size * run_time * classes
    true_labels = torch.from_numpy(np.array(labels))

    for index, image_batch in enumerate(tqdm(data)):
        network_input = encode_image_batch(image_batch)
        network.run(network_input, time=RUN_TIME)
        spikes = network.monitors["OUT"].get("s")
        activities[index, :, :] = spikes[-RUN_TIME:, 0]

    assignments = assign_labels(activities, true_labels, len(SUBJECTS))
    predicted_labels = all_activity(activities, assignments[0], len(SUBJECTS))
    print(classification_report(true_labels, predicted_labels))
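The helper encode_image_batch above is not shown. A minimal sketch, assuming BindsNET's Poisson encoder and an input layer named "IN" (the layer name is a guess; only the output layer "OUT" appears in the example):

from bindsnet.encoding import poisson

def encode_image_batch(image_batch, run_time=RUN_TIME, dt=1.0):
    # Hypothetical reconstruction: Poisson-encode pixel intensities into a
    # spike train of shape (run_time, *image_batch.shape), keyed by the
    # assumed input layer name.
    return {"IN": poisson(datum=image_batch, time=run_time, dt=dt)}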
Example #2
def update_curves(curves: Dict[str, list], labels: torch.Tensor, n_classes: int,
                  **kwargs) -> Tuple[Dict[str, list], Dict[str, torch.Tensor]]:
    # language=rst
    """
    Updates accuracy curves for each classification scheme.

    :param curves: Mapping from name of classification scheme to list of accuracy evaluations.
    :param labels: One-dimensional ``torch.Tensor`` of integer data labels.
    :param n_classes: Number of data categories.
    :param kwargs: Additional keyword arguments for classification scheme evaluation functions.
    :return: Updated accuracy curves and predictions.
    """
    predictions = {}
    for scheme in curves:
        # Branch based on name of classification scheme
        if scheme == 'all':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']

            prediction = all_activity(spike_record, assignments, n_classes)
        elif scheme == 'proportion':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']
            proportions = kwargs['proportions']

            prediction = proportion_weighting(spike_record, assignments,
                                              proportions, n_classes)
        elif scheme == 'ngram':
            spike_record = kwargs['spike_record']
            ngram_scores = kwargs['ngram_scores']
            n = kwargs['n']

            prediction = ngram(spike_record, ngram_scores, n_classes, n)
        elif scheme == 'logreg':
            full_spike_record = kwargs['full_spike_record']
            logreg = kwargs['logreg']

            prediction = logreg_predict(spikes=full_spike_record,
                                        logreg=logreg)
        else:
            raise NotImplementedError

        # Compute accuracy with current classification scheme.
        predictions[scheme] = prediction
        accuracy = torch.sum(labels.long() == prediction).float() / len(labels)
        curves[scheme].append(100 * accuracy)

    return curves, predictions
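A hedged usage sketch of update_curves for the 'all' and 'proportion' schemes; the variable names (spike_record, assignments, proportions) follow the other examples on this page:

curves = {'all': [], 'proportion': []}
curves, predictions = update_curves(
    curves,
    labels=label_tensor,        # one-dimensional tensor of integer labels
    n_classes=10,
    spike_record=spike_record,  # (n_samples, time, n_neurons) spike tensor
    assignments=assignments,    # per-neuron label assignments
    proportions=proportions,    # per-neuron label proportions
)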
Example #3
def predict(labeled_batches):
    print(f"predicting {len(labeled_batches)} batches...")

    n_samples = len(labeled_batches)
    n_classes = len(TARGETS)

    true_labels = torch.zeros(n_samples)
    activities = torch.zeros(n_samples, ENCODE_WINDOW, n_classes)

    sample_idx = 0
    for label, img_batch in tqdm(labeled_batches):
        run_single_batch(img_batch)
        true_labels[sample_idx] = label  # record the ground-truth label
        activities[sample_idx, :, :] = get_result_activity()
        sample_idx += 1

    assignments, _, _ = assign_labels(activities, true_labels, n_classes)
    pred_labels = all_activity(activities, assignments, n_classes)
    return pred_labels
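As in Example #1, the predictions can be scored against the ground truth; a minimal sketch, assuming labeled_batches yields (label, image_batch) pairs as in the loop above:

from sklearn.metrics import classification_report

true_labels = torch.tensor([label for label, _ in labeled_batches])
pred_labels = predict(labeled_batches)
print(classification_report(true_labels, pred_labels))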
Example #4
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

# Train the network.
print("Begin training.\n")

pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
    if i > n_train:
        break
    image = dataPoint["encoded_image"]
    label = dataPoint["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, 10)
        proportion_pred = proportion_weighting(spike_record, assignments,
                                               proportions, 10)

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(label.long() == all_activity_pred).item() /
            update_interval)
        accuracy["proportion"].append(
            100 * torch.sum(label.long() == proportion_pred).item() /
            update_interval)

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(
                accuracy["all"]), np.max(accuracy["all"])))
Example #5
def main(args):
    if args.update_steps is None:
        # update_steps: number of batches to classify before updating the
        # plots; e.g. with batch_size=16 this is max(250 // 16, 1) = 15.
        args.update_steps = max(250 // args.batch_size, 1)

    # update_interval: number of images to classify before updating the
    # plots; with the values above, 15 * 16 = 240.
    update_interval = args.update_steps * args.batch_size

    # Set up GPU use; seed the RNG for reproducibility.
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":  # reduction applied to batched weight updates
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # input dimensions are 28x28=784
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1 marks unassigned neurons
    proportions = torch.zeros(args.n_neurons, n_classes)  # (n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)  # (n_neurons, n_classes)

    # Set up monitors for spikes: a Monitor records a network object's state
    # variables ("s" = spikes) over ``args.time`` timesteps of each simulation.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)
    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Remove any existing log directory, then set up a TensorBoard summary
    # writer; flush_secs is how often (in seconds) pending events and
    # summaries are flushed to disk.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)
    for epoch in range(args.n_epochs):  # default: 1
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a DataLoader to iterate over and batch the data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,  # reshuffle the data at every epoch
            num_workers=args.n_workers,
            pin_memory=args.gpu,  # copy tensors into CUDA pinned memory
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                # Record the all-activity accuracy at each update step.
                value = torch.mean(
                    (label_tensor.long() == all_activity_pred).float())
                value = value.item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Accumulate this batch's labels for the current update window.
            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}  # move to GPU

            # Run the network on the input.
            t0 = time()
            network.run(inputs=inpts, time=args.time,
                        one_step=args.one_step)  # simulate for args.time steps
            t1 = time() - t0

            # Add to spike recording (ring buffer over the update window).
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)
            # if(step==1):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data = asarray(an_array)
            #     savetxt('data.csv',data)
            #     print("Beginning weights saved")
            # if(step==3749):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data2 = asarray(an_array)
            #     savetxt('data2.csv',data2)
            #     print("Ending weights saved")
            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                # print("Weights:",input_exc_weights)
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
    print(end_accuracy())
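The spike_record indexing above treats the tensor as a ring buffer over the update window; a self-contained toy demo of the same arithmetic (sizes are made up for illustration):

import torch

update_interval, time_steps, n_neurons, batch_size = 8, 5, 3, 2
spike_record = torch.zeros(update_interval, time_steps, n_neurons)

for step in range(10):
    s = torch.full((batch_size, time_steps, n_neurons), float(step))  # fake batch
    offset = (step * batch_size) % update_interval  # wraps around the window
    spike_record[offset:offset + s.size(0)] = s     # overwrite oldest entries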
Example #6
            inputs = {"X": a}
            #print(inputs["X"].sum()/32)
        else:
            inputs = {"X": batch["encoded_image"]}
            #print(inputs["X"].sum()/32)
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}

        if step % update_steps == 0 and step > 0:

            # Convert the array of labels into a tensor
            label_tensor = torch.tensor(labels)

            # Get network predictions.
            all_activity_pred = all_activity(spikes=spike_record,
                                             assignments=assignments,
                                             n_labels=n_classes)
            proportion_pred = proportion_weighting(
                spikes=spike_record,
                assignments=assignments,
                proportions=proportions,
                n_labels=n_classes,
            )

            # Compute network accuracy according to available classification strategies.
            accuracy["all"].append(
                100 *
                torch.sum(label_tensor.long() == all_activity_pred).item() /
                len(label_tensor))
            accuracy["proportion"].append(
                100 *
Example #7
voltage_axes = None
voltage_ims = None

pbar = tqdm(enumerate(dataloader_train))
for (i, datum) in pbar:
    if i > n_train:
        break

    image = datum["encoded_image"]
    label = datum["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    # Print training accuracy.
    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments,
                                         num_classes)
        proportion_pred = proportion_weighting(spike_record, assignments,
                                               proportions, num_classes)

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(labels.long() == all_activity_pred).item() /
            update_interval)
        accuracy["proportion"].append(
            100 * torch.sum(labels.long() == proportion_pred).item() /
            update_interval)

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(
                accuracy["all"]), np.max(accuracy["all"])))
Example #8
            if i % update_interval == 0 and i > 0:
                # Get a tensor of labels
                label_tensor = torch.Tensor(labels).to(device)

                # Get network predictions.
                if use_mnist:
                    confusion = DataFrame([[0] * n_classes
                                           for _ in range(n_classes)])
                else:
                    confusion = DataFrame([[0] * n_classes
                                           for _ in range(n_classes)],
                                          columns=kws,
                                          index=kws)
                all_activity_pred = all_activity(spike_record.to('cpu'),
                                                 assignments.to('cpu'),
                                                 n_classes).to(device)
                proportion_pred = proportion_weighting(spike_record.to('cpu'),
                                                       assignments.to('cpu'),
                                                       proportions.to('cpu'),
                                                       n_classes).to(device)
                for j in range(len(label_tensor)):
                    true_idx = label_tensor[j].long().item()
                    pred_idx = all_activity_pred[j].item()
                    if use_mnist:
                        confusion[true_idx][pred_idx] += 1
                    else:
                        confusion[kws[true_idx]][kws[pred_idx]] += 1

                # Compute network accuracy
                accuracy['all'].append(
Example #9
def main(args):
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [59000, 1000])

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up monitors for spikes and voltages
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            global_step = 60000 * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Disable learning.
                network.train(False)

                # Get test data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time,
                                                args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts,
                                time=args.time,
                                one_step=args.one_step)

                    # Add to spikes recording.
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    test_spike_record[(test_step * args.test_batch_size
                                       ):(test_step * args.test_batch_size) +
                                      s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons),
                            n_sqrt, 28)
                        spikes_ = {
                            layer: spikes[layer].get("s")[:, 0]
                            for layer in spikes
                        }
                        spike_ims, spike_axes = plot_spikes(spikes_,
                                                            ims=spike_ims,
                                                            axes=spike_axes)
                        weights_im = plot_weights(square_weights,
                                                  im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test",
                                  scalar_value=t1,
                                  global_step=global_step)

                # Convert the list of labels into a tensor.
                test_label_tensor = torch.tensor(test_labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long()
                         == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step",
                              scalar_value=t1,
                              global_step=global_step)

            # Add to spike recording (ring buffer over the update window).
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
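Note that Example #9 targets an older BindsNET release than Examples #5 and #10; a hedged mapping of the renamed calls:

# older API (this example)     ->  newer API
# network.run(inpts=..., ...)  ->  network.run(inputs=..., ...)
# network.reset_()             ->  network.reset_state_variables()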
Example #10
    def train(self, config=None):
        cfg = self.cfg if config is None else config

        update_interval = cfg['update_interval']
        time = cfg['time']
        n_neurons = cfg['network']['n_neurons']
        dataset, n_classes = self._init_dataset(cfg)

        # Record spikes during the simulation
        spike_record = torch.zeros(update_interval, time, n_neurons)

        # Neuron assignments and spike proportions
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)

        # Sequence of accuracy estimates
        accuracy = {"all": [], "proportion": []}

        # Set up monitors for spikes and voltages
        exc_voltage_monitor, inh_voltage_monitor, spikes, voltages = self._init_network_monitor(
            self.network, cfg)

        inpt_ims, inpt_axes = None, None
        spike_ims, spike_axes = None, None
        weights_im = None
        assigns_im = None
        perf_ax = None
        voltage_axes, voltage_ims = None, None

        print("\nBegin training.\n")
        iteration = 0
        for epoch in range(cfg['epochs']):
            print("Progress: %d / %d" % (epoch, cfg['epochs']))
            labels = []
            start_time = T.time()

            dataloader = DataLoader(dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=cfg['n_workers'])

            for step, batch in enumerate(tqdm(dataloader)):
                # Get next input sample.
                inputs = {'X': batch["encoded_image"].view(time, 1, 1, 28, 28)}
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                if step % update_interval == 0 and step > 0:
                    # Convert the array of labels into a tensor
                    label_tensor = torch.tensor(labels)

                    # Get network predictions.
                    all_activity_pred = all_activity(spikes=spike_record,
                                                     assignments=assignments,
                                                     n_labels=n_classes)
                    proportion_pred = proportion_weighting(
                        spikes=spike_record,
                        assignments=assignments,
                        proportions=proportions,
                        n_labels=n_classes,
                    )

                    # Compute network accuracy according to available classification strategies.
                    accuracy["all"].append(100 * torch.sum(
                        label_tensor.long() == all_activity_pred).item() /
                                           len(label_tensor))
                    accuracy["proportion"].append(100 * torch.sum(
                        label_tensor.long() == proportion_pred).item() /
                                                  len(label_tensor))

                    iteration += len(label_tensor)

                    print(
                        "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
                        % (
                            accuracy["all"][-1],
                            np.mean(accuracy["all"]),
                            np.max(accuracy["all"]),
                        ))
                    print(
                        "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
                        % (
                            accuracy["proportion"][-1],
                            np.mean(accuracy["proportion"]),
                            np.max(accuracy["proportion"]),
                        ))

                    self.recorder.insert(
                        (iteration, accuracy["all"][-1],
                         np.mean(accuracy["all"]), np.max(accuracy["all"]),
                         accuracy["proportion"][-1],
                         np.mean(accuracy["proportion"]),
                         np.max(accuracy["proportion"])))

                    assignments, proportions, rates = assign_labels(
                        spikes=spike_record,
                        labels=label_tensor,
                        n_labels=n_classes,
                        rates=rates,
                    )

                    labels = []

                labels.append(batch["label"])

                # Run the network on the input.
                self.network.run(inputs=inputs, time=time, input_time_dim=1)

                # Get voltage recording.
                exc_voltages = exc_voltage_monitor.get("v")
                inh_voltages = inh_voltage_monitor.get("v")

                # Add to spikes recording.
                spike_record[step % update_interval] = spikes["Ae"].get(
                    "s").squeeze()

                # Reset state variables
                self.network.reset_state_variables()

                if step % 1000 == 0:
                    self.save(cfg=cfg)

            print("Progress: %d / %d (%.4f seconds)" %
                  (epoch + 1, cfg['epochs'], T.time() - start_time))

        self.recorder.write(self.save_dir, cfg['name'])
        print("Training complete.\n")
        return None
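The helper _init_network_monitor used above is not shown. A minimal sketch consistent with the monitors the training loop expects; the inhibitory layer name "Ai" is an assumption (the loop itself only reads spikes["Ae"]):

from bindsnet.network.monitors import Monitor

def _init_network_monitor(self, network, cfg):
    time = cfg['time']
    # Hypothetical reconstruction: voltage monitors for the excitatory and
    # inhibitory layers, plus a spike monitor per layer.
    exc_voltage_monitor = Monitor(network.layers["Ae"], state_vars=["v"], time=time)
    inh_voltage_monitor = Monitor(network.layers["Ai"], state_vars=["v"], time=time)
    network.add_monitor(exc_voltage_monitor, name="exc_voltage")
    network.add_monitor(inh_voltage_monitor, name="inh_voltage")

    spikes, voltages = {}, {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    return exc_voltage_monitor, inh_voltage_monitor, spikes, voltages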
Example #11
perf_ax = None
voltage_axes = None
voltage_ims = None

pbar = tqdm(enumerate(dataloader))
for (i, datum) in pbar:
    if i > n_train:
        break

    image = datum["encoded_image"]
    label = datum["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, 10)
        proportion_pred = proportion_weighting(
            spike_record, assignments, proportions, 10
        )

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval
        )
        accuracy["proportion"].append(
            100 * torch.sum(labels.long() == proportion_pred).item() / update_interval
        )

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
Example #12
def main():
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from the inhibitory to the excitatory layer
    exc = 22.5
    lr = 1e-2  # learning rate
    lr_decay = 0.99  # learning rate decay
    time = 350  # duration of each sample after Poisson encoding
    dt = 1  # simulation timestep
    theta_plus = 0.05  # post-spike threshold increase
    tc_theta_decay = 1e7  # threshold decay time constant
    intensity = 0.25  # factor to multiply the input by (cf. Diehl & Cook)
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load a pre-trained network from disk
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    # TRAINING
    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_train.png"
    directories = [
        "plots_snn", "plots_snn/weights", "plots_snn/performance",
        "plots_snn/assaiments"
    ]
    for directory in directories:
        if not os.path.exists(directory):
            os.makedirs(directory)
    assert n_train % update_interval == 0
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Build network
    if load_network:
        network = load('net_output.pt')  # here goes file with network to load
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            inpt_shape=(1, 28, 28),
        )
    if gpu:
        network.to("cuda")
    # Pull dataset
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1)

    # Spike recording
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization
    if load_network:
        assignments, proportions, rates, ngram_scores = torch.load(
            'parameter_output.pt')
    else:
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)
        ngram_scores = {}
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    best_accuracy = 0

    # Initialize spike monitors
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)
    i = 0
    current_labels = torch.zeros(update_interval)
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None
    # train
    train_time = t.time()

    current_labels = torch.zeros(update_interval)
    time1 = t.time()
    for j in range(n_epochs):
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break
            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {(t.time()-time1)} s')
                time1 = t.time()
            if i % update_interval == 0 and i > 0:
                #network.connections['X','Y'].update_rule.nu[1] *= lr_decay
                curves, preds = update_curves(curves,
                                              current_labels,
                                              n_classes,
                                              spike_record=spike_record,
                                              assignments=assignments,
                                              proportions=proportions,
                                              ngram_scores=ngram_scores,
                                              n=2)
                print_results(curves)
                for scheme in preds:
                    predictions[scheme] = torch.cat(
                        [predictions[scheme], preds[scheme]], -1)
                # Accuracy curves
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network and parameters to disk.
                    network.save('net_output.pt')
                    path = "parameters_output.pt"
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)
            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}
            # Run the network on the input.
            network.run(inputs=inpts, time=time)
            retries = 0
            # Spike recording
            spike_record[i % update_interval] = spikes['Ae'].get('s').view(
                time, n_neurons)
            full_spike_record[i] = spikes['Ae'].get('s').view(
                time, n_neurons).sum(0).long()
            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w
                square_assignments = get_square_assignments(
                    assignments, n_sqrt)

                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im)
                if i % update_interval == 0:
                    square_weights = get_square_weights(
                        input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im)
                    [weights_im,
                     save_weights_fn] = plot_weights(square_weights,
                                                     im=weights_im,
                                                     save=save_weights_fn)
                inpt_axes, inpt_ims = plot_input(_input,
                                                 reconstruction,
                                                 label=label,
                                                 axes=inpt_axes,
                                                 ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(_spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im,
                                              save=save_assaiments_fn)
                perf_ax = plot_performance(curves,
                                           ax=perf_ax,
                                           save=save_performance_fn)
                plt.pause(1e-8)
            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()
            if i % 10 == 0 and i > 0:
                preds = all_activity(
                    spike_record[i % update_interval - 10:i % update_interval],
                    assignments, n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(
                    f'True value:  {current_labels[i % update_interval - 10:i % update_interval].numpy()}'
                )
            i += 1

        print(f'Epoch {j + 1}/{n_epochs} complete')
        torch.save(network.state_dict(), 'net_final.pt')
        path = "parameters_final.pt"
        torch.save((assignments, proportions, rates, ngram_scores),
                   open(path, 'wb'))
    print("Training completed. Training took " +
          str((t.time() - train_time) / 6) + " min.")
    print("Saving network...")
    network.save('net_final.pt')
    torch.save((assignments, proportions, rates, ngram_scores),
               open('parameters_final.pt', 'wb'))
    print("Network saved.")