Example #1
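These examples are code excerpts and omit their import blocks. A plausible set of imports for Example #1 is sketched below as an assumption (recent BindsNET release); ROOT_DIR, colorize, and max_without_indices are project-local helpers defined elsewhere in the original script.

import os
import shutil
from time import time

import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from bindsnet.analysis.plotting import plot_spikes, plot_weights
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
from bindsnet.models import DiehlAndCook2015v2
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_weights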
def main(args):
    if args.update_steps is None:
        args.update_steps = max(250 // args.batch_size, 1)

    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up monitors for spikes and voltages
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate and batch data
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step}")

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:
                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
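Example #1 (and Example #5 below) passes a max_without_indices helper as the reduction function, but its definition is not part of the excerpt. A minimal sketch, assuming it simply mirrors torch.sum/torch.mean by reducing over a dimension and discarding the argmax indices:

def max_without_indices(tensor, dim=0, keepdim=False):
    # torch.max over a dimension returns (values, indices); a learning-rule
    # reduction only needs the values, so drop the indices.
    return torch.max(tensor, dim=dim, keepdim=keepdim)[0]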
Example #2
network.add_layer(input_layer, name="X")
network.add_layer(conv_layer, name="Y")
network.add_connection(conv_conn, source="X", target="Y")
network.add_connection(recurrent_conn, source="Y", target="Y")

# Voltage recording for the output layer.
voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
network.add_monitor(voltage_monitor, name="output_voltage")

if gpu:
    network.to("cuda")

# Load MNIST data.
train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "../../data/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * intensity)]),
)

spikes = {}
for layer in set(network.layers):
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
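The excerpt ends right after initializing the voltages dict. A likely continuation, mirroring the spike-monitor loop above and the voltage-monitor loop in Example #5 (an assumption, since the rest of this script is not shown):

for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)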
Example #3
network.add_connection(FF1a, source="I_a", target="rTNN_1")
# (Recurrences)
network.add_connection(rTNN_to_buf1, source="rTNN_1", target="BUF_1")
network.add_connection(buf1_to_rTNN, source="BUF_1", target="rTNN_1")

# End of network creation

# Monitors:
spikes = {}
for layer in network.layers:
    spikes[layer] = Monitor(network.layers[layer], ["s"], time=num_timesteps)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

# Data and initial encoding:
dataset = MNIST(
    PoissonEncoder(time=num_timesteps, dt=1),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * intensity)]),
)

# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=0,
                                         pin_memory=False)
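The excerpt stops after building the dataloader. A minimal sketch of consuming it, assuming the same BindsNET API used in the other excerpts (inpts= and reset_()) and that "I_a" is the input layer named in the connections above; n_samples is a hypothetical cap:

n_samples = 1000  # hypothetical number of images to present
for step, batch in enumerate(dataloader):
    if step >= n_samples:
        break
    # Present the Poisson-encoded image to the input layer.
    inpts = {"I_a": batch["encoded_image"]}
    network.run(inpts=inpts, time=num_timesteps)
    # Read back the recorded spikes for each monitored layer, then reset state.
    recorded = {layer: spikes[layer].get("s") for layer in spikes}
    network.reset_()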
Example #4
    def create_network(self, norm=0.5, competitive_weight=-100.):
        self.norm = norm
        self.competitive_weight = competitive_weight
        self.time_max = 30
        dt = 1
        intensity = 127.5

        self.train_dataset = MNIST(
            PoissonEncoder(time=self.time_max, dt=dt),
            None,
            "MNIST",
            download=False,
            train=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
                )
            )

        # Hyperparameters
        n_filters = 25
        kernel_size = 12
        stride = 4
        padding = 0
        conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
        per_class = int((n_filters * conv_size * conv_size) / 10)
        tc_trace = 20.  # grid search check
        tc_decay = 20.
        thresh = -52
        refrac = 5

        wmin = 0
        wmax = 1

        # Network
        self.network = Network(learning=True)
        self.GlobalMonitor = NetworkMonitor(self.network, state_vars=('v', 's', 'w'))


        self.input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

        self.output_layer = AdaptiveLIFNodes(
            n=n_filters * conv_size * conv_size,
            shape=(n_filters, conv_size, conv_size),
            traces=True,
            thresh=thresh,
            tc_trace=tc_trace,
            tc_decay=tc_decay,
            theta_plus=0.05,
            tc_theta_decay=1e6)


        self.connection_XY = LocalConnection(
            self.input_layer,
            self.output_layer,
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            update_rule=PostPre,
            norm=norm,  # normalization constant; alternatives tried: 1 / kernel_size ** 2, 0.4 * kernel_size ** 2
            nu=[1e-4, 1e-2],
            wmin=wmin,
            wmax=wmax)

        # competitive connections
        w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    # change
                    for i in range(conv_size):
                        for j in range(conv_size):
                            w[fltr1, i, j, fltr2, i, j] = competitive_weight

        self.connection_YY = Connection(self.output_layer, self.output_layer, w=w)

        self.network.add_layer(self.input_layer, name='X')
        self.network.add_layer(self.output_layer, name='Y')

        self.network.add_connection(self.connection_XY, source='X', target='Y')
        self.network.add_connection(self.connection_YY, source='Y', target='Y')

        self.network.add_monitor(self.GlobalMonitor, name='Network')

        self.spikes = {}
        for layer in set(self.network.layers):
            self.spikes[layer] = Monitor(self.network.layers[layer], state_vars=["s"], time=self.time_max)
            self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)
            #print('GlobalMonitor.state_vars:', self.GlobalMonitor.state_vars)

        self.voltages = {}
        for layer in set(self.network.layers) - {"X"}:
            self.voltages[layer] = Monitor(self.network.layers[layer], state_vars=["v"], time=self.time_max)
            self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)
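create_network is a method of a wrapper class whose remaining code is not shown. A hypothetical companion method (an assumption, not part of the original class) that presents one pre-encoded MNIST sample and reads back output spikes, mirroring the run/reset pattern of the other examples:

    def run_sample(self, index=0):
        # Fetch one sample; the dataset returns a dict with the Poisson-encoded
        # image and its label (shapes assumed to match the other excerpts).
        sample = self.train_dataset[index]
        inpts = {"X": sample["encoded_image"]}
        self.network.run(inpts=inpts, time=self.time_max)

        # Total spikes emitted by each output neuron over the presentation.
        spike_counts = self.spikes["Y"].get("s").sum(0)

        # Reset state variables before the next sample.
        self.network.reset_()
        return sample["label"], spike_counts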
Example #5
def main(args):
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    conv_size = int((28 - args.kernel_size + 2 * args.padding) / args.stride) + 1

    # Build network.
    network = Network()
    input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    conv_layer = DiehlAndCookNodes(
        n=args.n_filters * conv_size * conv_size,
        shape=(args.n_filters, conv_size, conv_size),
        traces=True,
    )

    conv_conn = Conv2dConnection(
        input_layer,
        conv_layer,
        kernel_size=args.kernel_size,
        stride=args.stride,
        update_rule=PostPre,
        norm=0.4 * args.kernel_size ** 2,
        nu=[0, args.lr],
        reduction=max_without_indices,
        wmax=1.0,
    )

    w = torch.zeros(
        args.n_filters, conv_size, conv_size, args.n_filters, conv_size, conv_size
    )
    for fltr1 in range(args.n_filters):
        for fltr2 in range(args.n_filters):
            if fltr1 != fltr2:
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = -100.0

    w = w.view(
        args.n_filters * conv_size * conv_size, args.n_filters * conv_size * conv_size
    )
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name="X")
    network.add_layer(conv_layer, name="Y")
    network.add_connection(conv_conn, source="X", target="Y")
    network.add_connection(recurrent_conn, source="Y", target="Y")

    # Voltage recording for the output layer.
    voltage_monitor = Monitor(network.layers["Y"], ["v"], time=args.time)
    network.add_monitor(voltage_monitor, name="output_voltage")

    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    train_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity)]
        ),
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    voltages = {}
    for layer in set(network.layers) - {"X"}:
        voltages[layer] = Monitor(
            network.layers[layer], state_vars=["v"], time=args.time
        )
        network.add_monitor(voltages[layer], name="%s_voltages" % layer)

    # Train the network.
    print("Begin training.\n")
    start = time()

    weights_im = None

    for epoch in range(args.n_epochs):
        if epoch % args.progress_interval == 0:
            print(
                "Progress: %d / %d (%.4f seconds)"
                % (epoch, args.n_epochs, time() - start)
            )
            start = time()

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=args.gpu,
        )

        variance_buffers = {}
        for k in network.connections.keys():
            variance_buffers[k] = {}
            variance_buffers[k]["prev"] = network.connections[k].w.clone()
            variance_buffers[k]["sum"] = torch.zeros_like(
                variance_buffers[k]["prev"], dtype=torch.double
            )
            variance_buffers[k]["sum_squares"] = torch.zeros_like(
                variance_buffers[k]["prev"], dtype=torch.double
            )
            variance_buffers[k]["count"] = 0

        for step, batch in enumerate(tqdm(train_dataloader)):
            # Get next input sample.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input.
            network.run(inpts=inpts, time=args.time, input_time_dim=0)

            # manually compute the total update from this run
            for k in network.connections.keys():
                cur = network.connections[k].w.clone()
                weight_update = (cur - variance_buffers[k]["prev"]).double()
                variance_buffers[k]["sum"] += weight_update
                variance_buffers[k]["sum_squares"] += weight_update * weight_update
                variance_buffers[k]["count"] += 1
                variance_buffers[k]["prev"] = cur

            if step % 1000 == 999:
                process_variance_buffers(variance_buffers)

            # Decay learning rate.
            network.connections["X", "Y"].nu[1] *= 0.99

            # Optionally plot various simulation information.
            if args.plot:
                weights = conv_conn.w
                weights_im = plot_conv2d_weights(weights, im=weights_im)

                plt.pause(1e-8)

            network.reset_()  # Reset state variables.

    print(
        "Progress: %d / %d (%.4f seconds)\n"
        % (args.n_epochs, args.n_epochs, time() - start)
    )
    print("Training complete.\n")
    process_variance_buffers(variance_buffers)
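process_variance_buffers is called above but not defined in this excerpt. A minimal sketch consistent with how the buffers are filled (running sum, sum of squares, and count of per-step weight updates), assuming the goal is to report the variance of the weight updates per connection:

def process_variance_buffers(variance_buffers):
    # Variance of the per-step weight updates from accumulated moments:
    # Var[x] = E[x^2] - E[x]^2, computed element-wise per weight.
    for k, buf in variance_buffers.items():
        n = buf["count"]
        if n == 0:
            continue
        mean = buf["sum"] / n
        variance = buf["sum_squares"] / n - mean ** 2
        print(
            "Connection %s: mean per-weight update variance = %.6e"
            % (str(k), variance.mean().item())
        )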