Example 1
def update_curves(curves: Dict[str,
                               list], labels: torch.Tensor, n_classes: int,
                  **kwargs) -> Tuple[Dict[str, list], Dict[str, torch.Tensor]]:
    # language=rst
    """
    Updates accuracy curves for each classification scheme.

    :param curves: Mapping from name of classification scheme to list of accuracy evaluations.
    :param labels: One-dimensional ``torch.Tensor`` of integer data labels.
    :param n_classes: Number of data categories.
    :param kwargs: Additional keyword arguments for classification scheme evaluation functions.
    :return: Updated accuracy curves and predictions.
    """
    predictions = {}
    for scheme in curves:
        # Branch based on name of classification scheme
        if scheme == 'all':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']

            prediction = all_activity(spike_record, assignments, n_classes)
        elif scheme == 'proportion':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']
            proportions = kwargs['proportions']

            prediction = proportion_weighting(spike_record, assignments,
                                              proportions, n_classes)
        elif scheme == 'ngram':
            spike_record = kwargs['spike_record']
            ngram_scores = kwargs['ngram_scores']
            n = kwargs['n']

            prediction = ngram(spike_record, ngram_scores, n_classes, n)
        elif scheme == 'logreg':
            full_spike_record = kwargs['full_spike_record']
            logreg = kwargs['logreg']

            prediction = logreg_predict(spikes=full_spike_record,
                                        logreg=logreg)
        else:
            raise NotImplementedError

        # Compute accuracy with current classification scheme.
        predictions[scheme] = prediction
        accuracy = torch.sum(labels.long() == prediction).float() / len(labels)
        curves[scheme].append(100 * accuracy)

    return curves, predictions
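
A minimal sketch of how ``update_curves`` might be called, assuming BindsNET's evaluation helpers are importable alongside it; the tensor shapes, scheme names, and zero-initialized state below are illustrative assumptions, not values from the snippet above:

import torch

# Hypothetical sizes: 250 recorded samples, 100 output neurons, 10 classes.
n_samples, n_neurons, n_classes = 250, 100, 10

curves = {'all': [], 'proportion': []}                 # one accuracy list per scheme
spike_record = torch.zeros(n_samples, 250, n_neurons)  # (samples, time, neurons)
assignments = -torch.ones(n_neurons)                   # per-neuron class assignments
proportions = torch.zeros(n_neurons, n_classes)
labels = torch.randint(0, n_classes, (n_samples,))

curves, predictions = update_curves(
    curves, labels, n_classes,
    spike_record=spike_record,
    assignments=assignments,
    proportions=proportions,
)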
Example 2
# Train the network.
print("Begin training.\n")

pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
    if i > n_train:
        break
    image = dataPoint["encoded_image"]
    label = dataPoint["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, 10)
        proportion_pred = proportion_weighting(spike_record, assignments,
                                               proportions, 10)

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(labels.long() == all_activity_pred).item() /
            update_interval)
        accuracy["proportion"].append(
            100 * torch.sum(labels.long() == proportion_pred).item() /
            update_interval)

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(
                accuracy["all"]), np.max(accuracy["all"])))
        print(
            "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
            % (accuracy["proportion"][-1], np.mean(
                accuracy["proportion"]), np.max(accuracy["proportion"])))
Example 3
def main(args):
    if args.update_steps is None:
        # Number of batches to classify before updating the accuracy plots.
        args.update_steps = max(250 // args.batch_size, 1)

    # Number of images to classify before updating the accuracy plots.
    update_interval = args.update_steps * args.batch_size

    # Set up GPU use and seed the RNGs for reproducibility.
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # Select how per-batch weight updates are reduced to a single update.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # input dimensions are 28x28 = 784
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1 marks unassigned neurons
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up monitors recording the spike state variable "s" from every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Delete any existing log directory so the summary writer starts fresh.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer: logs metrics to ``log_dir`` for TensorBoard visualization,
    # flushing pending events to disk every ``flush_secs`` seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)
    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate over and batch the data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,  # reshuffle the data at every epoch
            num_workers=args.n_workers,
            pin_memory=args.gpu,  # copy tensors into CUDA pinned memory
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                # Record the "all activity" accuracy at each update step.
                value = torch.mean(
                    (label_tensor.long() == all_activity_pred).float()).item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Collect this batch's labels for the next update step.
            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording (circular buffer over the update interval).
            s = spikes["Y"].get("s").permute((1, 0, 2))
            spike_record[(step * args.batch_size) %
                         update_interval:(step * args.batch_size %
                                          update_interval) + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)
            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                # print("Weights:",input_exc_weights)
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
        print(end_accuracy())
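
The ``max_without_indices`` reduction referenced above is not defined in this snippet; a plausible one-line sketch, assuming it should match the ``(tensor, dim) -> tensor`` signature of ``torch.sum`` while discarding the indices that ``torch.max`` also returns:

import torch

def max_without_indices(tensor: torch.Tensor, dim: int = 0, **kwargs) -> torch.Tensor:
    # torch.max along a dim returns (values, indices); keep only the values.
    return torch.max(tensor, dim=dim, **kwargs)[0]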
Example 4
            #print(inputs["X"].sum()/32)
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}

        if step % update_steps == 0 and step > 0:

            # Convert the array of labels into a tensor
            label_tensor = torch.tensor(labels)

            # Get network predictions.
            all_activity_pred = all_activity(spikes=spike_record,
                                             assignments=assignments,
                                             n_labels=n_classes)
            proportion_pred = proportion_weighting(
                spikes=spike_record,
                assignments=assignments,
                proportions=proportions,
                n_labels=n_classes,
            )

            # Compute network accuracy according to available classification strategies.
            accuracy["all"].append(
                100 *
                torch.sum(label_tensor.long() == all_activity_pred).item() /
                len(label_tensor))
            accuracy["proportion"].append(
                100 *
                torch.sum(label_tensor.long() == proportion_pred).item() /
                len(label_tensor))

            print(
                "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
                % (accuracy["all"][-1], np.mean(
                    accuracy["all"]), np.max(accuracy["all"])))
Example 5
pbar = tqdm(enumerate(dataloader_train))
for (i, datum) in pbar:
    if i > n_train:
        break

    image = datum["encoded_image"]
    label = datum["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    # Print training accuracy.
    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments,
                                         num_classes)
        proportion_pred = proportion_weighting(spike_record, assignments,
                                               proportions, num_classes)

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(labels.long() == all_activity_pred).item() /
            update_interval)
        accuracy["proportion"].append(
            100 * torch.sum(labels.long() == proportion_pred).item() /
            update_interval)

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(
                accuracy["all"]), np.max(accuracy["all"])))
        print(
            "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
            % (accuracy["proportion"][-1], np.mean(
                accuracy["proportion"]), np.max(accuracy["proportion"])))
Example 6
                label_tensor = torch.Tensor(labels).to(device)

                # Get network predictions.
                if use_mnist:
                    confusion = DataFrame([[0] * n_classes
                                           for _ in range(n_classes)])
                else:
                    confusion = DataFrame([[0] * n_classes
                                           for _ in range(n_classes)],
                                          columns=kws,
                                          index=kws)
                all_activity_pred = all_activity(spike_record.to('cpu'),
                                                 assignments.to('cpu'),
                                                 n_classes).to(device)
                proportion_pred = proportion_weighting(spike_record.to('cpu'),
                                                       assignments.to('cpu'),
                                                       proportions.to('cpu'),
                                                       n_classes).to(device)
                for j in range(len(label_tensor)):
                    true_idx = label_tensor[j].long().item()
                    pred_idx = all_activity_pred[j].item()
                    if use_mnist:
                        confusion[true_idx][pred_idx] += 1
                    else:
                        confusion[kws[true_idx]][kws[pred_idx]] += 1

                # Compute network accuracy
                accuracy['all'].append(
                    100 *
                    torch.sum(label_tensor.long() == all_activity_pred).item() /
                    update_interval)
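
As a design note, the same confusion bookkeeping can be kept in a plain tensor instead of a pandas DataFrame; a self-contained sketch with made-up labels and predictions (rows index true labels, columns index predictions):

import torch

n_classes = 10
label_tensor = torch.randint(0, n_classes, (240,))       # hypothetical true labels
all_activity_pred = torch.randint(0, n_classes, (240,))  # hypothetical predictions

confusion = torch.zeros(n_classes, n_classes, dtype=torch.long)
# Accumulate all counts in one vectorized call instead of a Python loop.
confusion.index_put_(
    (label_tensor, all_activity_pred),
    torch.ones_like(label_tensor),
    accumulate=True,
)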
Example 7
def main(args):
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [59000, 1000])

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up monitors for spikes and voltages
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            global_step = 60000 * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Disable learning.
                network.train(False)

                # Get test data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time,
                                                args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts,
                                time=args.time,
                                one_step=args.one_step)

                    # Add to spikes recording.
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    test_spike_record[(test_step * args.test_batch_size
                                       ):(test_step * args.test_batch_size) +
                                      s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons),
                            n_sqrt, 28)
                        spikes_ = {
                            layer: spikes[layer].get("s")[:, 0]
                            for layer in spikes
                        }
                        spike_ims, spike_axes = plot_spikes(spikes_,
                                                            ims=spike_ims,
                                                            axes=spike_axes)
                        weights_im = plot_weights(square_weights,
                                                  im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test",
                                  scalar_value=t1,
                                  global_step=global_step)

                # Convert the list of labels into a tensor.
                test_label_tensor = torch.tensor(test_labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long()
                         == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step",
                              scalar_value=t1,
                              global_step=global_step)

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            spike_record[(step * args.batch_size) %
                         update_interval:(step * args.batch_size %
                                          update_interval) + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
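
The ``spike_record`` slice assignment above implements a circular buffer: each batch is written starting at ``(step * batch_size) % update_interval``, so the record always holds the most recent ``update_interval`` samples. A standalone sketch of that indexing with made-up sizes:

import torch

update_interval, batch_size, time_steps, n_neurons = 240, 16, 250, 100
spike_record = torch.zeros(update_interval, time_steps, n_neurons)

for step in range(45):  # more batches than fit in the buffer
    s = torch.rand(batch_size, time_steps, n_neurons)  # stand-in for a batch of spikes
    start = (step * batch_size) % update_interval
    spike_record[start:start + s.size(0)] = s  # oldest rows are overwritten in place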
Example 8
    def train(self, config=None):
        cfg = self.cfg if config is None else config

        update_interval = cfg['update_interval']
        time = cfg['time']
        n_neurons = cfg['network']['n_neurons']
        dataset, n_classes = self._init_dataset(cfg)

        # Record spikes during the simulation
        spike_record = torch.zeros(update_interval, time, n_neurons)

        # Neuron assignments and spike proportions
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)

        # Sequence of accuracy estimates
        accuracy = {"all": [], "proportion": []}

        # Set up monitors for spikes and voltages
        exc_voltage_monitor, inh_voltage_monitor, spikes, voltages = self._init_network_monitor(
            self.network, cfg)

        inpt_ims, inpt_axes = None, None
        spike_ims, spike_axes = None, None
        weights_im = None
        assigns_im = None
        perf_ax = None
        voltage_axes, voltage_ims = None, None

        print("\nBegin training.\n")
        iteration = 0
        for epoch in range(cfg['epochs']):
            print("Progress: %d / %d" % (epoch, cfg['epochs']))
            labels = []
            start_time = T.time()

            dataloader = DataLoader(dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=cfg['n_workers'])

            for step, batch in enumerate(tqdm(dataloader)):
                # Get next input sample.
                inputs = {'X': batch["encoded_image"].view(time, 1, 1, 28, 28)}
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                if step % update_interval == 0 and step > 0:
                    # Convert the array of labels into a tensor
                    label_tensor = torch.tensor(labels)

                    # Get network predictions.
                    all_activity_pred = all_activity(spikes=spike_record,
                                                     assignments=assignments,
                                                     n_labels=n_classes)
                    proportion_pred = proportion_weighting(
                        spikes=spike_record,
                        assignments=assignments,
                        proportions=proportions,
                        n_labels=n_classes,
                    )

                    # Compute network accuracy according to available classification strategies.
                    accuracy["all"].append(100 * torch.sum(
                        label_tensor.long() == all_activity_pred).item() /
                                           len(label_tensor))
                    accuracy["proportion"].append(100 * torch.sum(
                        label_tensor.long() == proportion_pred).item() /
                                                  len(label_tensor))

                    iteration += len(label_tensor)

                    print(
                        "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
                        % (
                            accuracy["all"][-1],
                            np.mean(accuracy["all"]),
                            np.max(accuracy["all"]),
                        ))
                    print(
                        "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
                        % (
                            accuracy["proportion"][-1],
                            np.mean(accuracy["proportion"]),
                            np.max(accuracy["proportion"]),
                        ))

                    self.recorder.insert(
                        (iteration, accuracy["all"][-1],
                         np.mean(accuracy["all"]), np.max(accuracy["all"]),
                         accuracy["proportion"][-1],
                         np.mean(accuracy["proportion"]),
                         np.max(accuracy["proportion"])))

                    assignments, proportions, rates = assign_labels(
                        spikes=spike_record,
                        labels=label_tensor,
                        n_labels=n_classes,
                        rates=rates,
                    )

                    labels = []

                labels.append(batch["label"])

                # Run the network on the input.
                self.network.run(inputs=inputs, time=time, input_time_dim=1)

                # Get voltage recording.
                exc_voltages = exc_voltage_monitor.get("v")
                inh_voltages = inh_voltage_monitor.get("v")

                # Add to spikes recording.
                spike_record[step % update_interval] = spikes["Ae"].get(
                    "s").squeeze()

                # Reset state variables
                self.network.reset_state_variables()

                if step % 1000 == 0:
                    self.save(cfg=cfg)

            print("Progress: %d / %d (%.4f seconds)" %
                  (epoch + 1, cfg['epochs'], T.time() - start_time))

        self.recorder.write(self.save_dir, cfg['name'])
        print("Training complete.\n")
        return None
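
For intuition, a simplified sketch of the statistic ``assign_labels`` maintains: average per-neuron firing rates per class, with each neuron assigned the label it fires for most. This is a rough illustration of the idea, not BindsNET's implementation:

import torch

def assign_labels_sketch(spikes, labels, n_labels, rates):
    # spikes: (n_samples, time, n_neurons); labels: (n_samples,) integer tensor.
    spike_counts = spikes.sum(1)  # total spikes per (sample, neuron)
    for i in range(n_labels):
        n_labeled = torch.sum(labels == i).float()
        if n_labeled > 0:
            # Average firing of each neuron over the samples of class ``i``.
            rates[:, i] = spike_counts[labels == i].sum(0) / n_labeled
    proportions = rates / rates.sum(1, keepdim=True).clamp(min=1e-12)
    assignments = torch.argmax(rates, dim=1)  # strongest class per neuron
    return assignments, proportions, rates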
Example 9
voltage_ims = None

pbar = tqdm(enumerate(dataloader))
for (i, datum) in pbar:
    if i > n_train:
        break

    image = datum["encoded_image"]
    label = datum["label"]
    pbar.set_description_str("Train progress: (%d / %d)" % (i, n_train))

    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, 10)
        proportion_pred = proportion_weighting(
            spike_record, assignments, proportions, 10
        )

        # Compute network accuracy according to available classification strategies.
        accuracy["all"].append(
            100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval
        )
        accuracy["proportion"].append(
            100 * torch.sum(labels.long() == proportion_pred).item() / update_interval
        )

        print(
            "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
            % (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
        )
        print(
            "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
            % (
                accuracy["proportion"][-1],
                np.mean(accuracy["proportion"]),
                np.max(accuracy["proportion"]),
            )
        )