Example no. 1
    def __init__(self, network: Network, **kwargs) -> None:
        # language=rst
        """
        Initializes the pipeline.

        :param network: Arbitrary network object, will be managed by the
            ``BasePipeline`` class.

        Keyword arguments:

        :param int save_interval: How often to save the network to disk.
        :param str save_dir: Directory to save network object to.
        :param Dict[str, Any] plot_config: Dict containing the plot configuration.
            Includes length, type (``"color"`` or ``"line"``), and interval per plot
            type.
        :param int print_interval: Interval to print text output.
        :param bool allow_gpu: Allows automatic transfer to the GPU.
        """
        self.network = network

        # Network saving: where and how often to checkpoint the network to disk.
        self.save_dir = kwargs.get("save_dir", "network.pt")
        self.save_interval = kwargs.get("save_interval", None)

        # Handles plotting of all layer spikes and voltages.
        # This constructs monitors at every level.
        self.plot_config = kwargs.get("plot_config", {
            "data_step": True,
            "data_length": 100
        })

        if self.plot_config.get("data_step") is not None:
            for l in self.network.layers:
                self.network.add_monitor(
                    Monitor(self.network.layers[l], "s",
                            self.plot_config["data_length"]),
                    name=f"{l}_spikes",
                )
                if hasattr(self.network.layers[l], "v"):
                    self.network.add_monitor(
                        Monitor(self.network.layers[l], "v",
                                self.plot_config["data_length"]),
                        name=f"{l}_voltages",
                    )

        self.print_interval = kwargs.get("print_interval", None)
        self.test_interval = kwargs.get("test_interval", None)
        self.step_count = 0
        self.init_fn()
        self.clock = time.time()
        self.allow_gpu = kwargs.get("allow_gpu", True)

        if torch.cuda.is_available() and self.allow_gpu:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.network.to(self.device)
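
The constructor above is configured entirely through the keyword arguments listed in its docstring. The following is a minimal usage sketch, not taken from any of the examples: it assumes a concrete subclass (the hypothetical ``MyPipeline``) that supplies the ``init_fn`` hook the constructor calls, and that ``BasePipeline`` is importable from ``bindsnet.pipeline`` as in recent BindsNET releases.

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.pipeline import BasePipeline


class MyPipeline(BasePipeline):
    # Hypothetical subclass: the base constructor calls self.init_fn().
    def init_fn(self) -> None:
        pass


network = Network(dt=1.0)
network.add_layer(Input(n=100), name="X")
network.add_layer(LIFNodes(n=50), name="Y")

pipeline = MyPipeline(
    network,
    save_dir="checkpoints/network.pt",  # where the network is checkpointed
    save_interval=100,                  # save every 100 steps
    print_interval=10,                  # print progress every 10 steps
    allow_gpu=True,                     # move to CUDA when it is available
    plot_config={"data_step": True, "data_length": 100},
)
# Spike (and, where available, voltage) monitors are now registered for every
# layer, and pipeline.device reflects whether a GPU was found.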
Example no. 2
    def __init__(self, network, environment, encoding, action_function, output):
        self.network = network
        self.env = environment
        if encoding == 'bernoulli':
            self.encoding = bernoulli
        else:
            # Without this guard, self.encoding would be left undefined for other encodings.
            raise ValueError(f"Unsupported encoding: {encoding!r}")
        self.action_function = action_function
        self.output = output
        
        # settings
        self.print_interval = 1
        self.save_interval = 1
        self.save_dir = 'network.pt'
        self.plot_length = 1.0
        self.plot_interval = 1
        self.history_length = 1

        # time
        self.time = 100
        self.dt = network.dt
        self.timestep = int(self.time / self.dt)

        # variables to use in this pipeline
        self.episode = 0
        self.iteration = 0
        self.accumulated_reward = 0
        self.reward_list = []

        # for plot
        self.plot_type = "color"
        self.s_ims, self.s_axes = None, None
        self.v_ims, self.v_axes = None, None
        self.obs_im, self.obs_ax = None, None
        self.reward_im, self.reward_ax = None, None
        
        self.obs = None
        self.reward = None
        self.done = None
        self.action_name = None

        # Add a spike monitor (and a voltage monitor, where available) for every layer.
        for l in self.network.layers:
            self.network.add_monitor(
                Monitor(self.network.layers[l], 's',
                        int(self.plot_length * self.plot_interval * self.timestep)),
                name=f'{l}_spikes',
            )
            if 'v' in self.network.layers[l].__dict__:
                self.network.add_monitor(
                    Monitor(self.network.layers[l], 'v',
                            int(self.plot_length * self.plot_interval * self.timestep)),
                    name=f'{l}_voltages',
                )
        self.spike_record = {l: torch.Tensor().byte() for l in self.network.layers}
        self.set_spike_data()

        # Set up encoded-input placeholders; the network may have several input layers.
        self.encoded = {
            name: torch.Tensor() for name, layer in network.layers.items() if isinstance(layer, AbstractInput)
        }

        self.clock = time.time()
Example no. 3
    def __init__(self,
                 name,
                 model_type,
                 model_config,
                 preprocess_config,
                 policy=None,
                 device=None,
                 **kwargs):
        super(PongAgent, self).__init__(name, 'Pong-v0')

        self.device = init_torch_device(device)

        self.valid_action = kwargs.get('valid_action', None)
        self.action = self._init_action(self.valid_action)

        self.model = self._init_model(model_type, model_config, policy)
        self.transform = Transform(preprocess_config, self.device)

        self.model_type = model_type
        self.sim_time = kwargs.get('sim_time', None)
        self.output = kwargs.get('output', None)
        self.model_config = model_config
        self.preprocess_config = preprocess_config
        self.policy = policy
        self.memory = None
        self.note = None  # episode nums

        if self.model_type.lower() == 'snn' and self.output is not None:
            self.model.add_monitor(
                Monitor(self.model.layers[self.output], ["s"]), self.output)

            # Record output-layer spikes over the simulation window.
            self.spike_record = {
                self.output: torch.zeros((self.sim_time, len(self.action)))
            }
Example no. 4
    def test_add_objects(self):
        network = Network(dt=1.0, learning=False)

        inpt = Input(100)
        network.add_layer(inpt, name='X')
        lif = LIFNodes(50)
        network.add_layer(lif, name='Y')

        assert inpt == network.layers['X']
        assert lif == network.layers['Y']

        conn = Connection(inpt, lif)
        network.add_connection(conn, source='X', target='Y')

        assert conn == network.connections[('X', 'Y')]

        monitor = Monitor(lif, state_vars=['s', 'v'])
        network.add_monitor(monitor, 'Y')

        assert monitor == network.monitors['Y']

        network.save('net.pt')
        _network = load('net.pt', learning=True)
        assert _network.learning
        assert 'X' in _network.layers
        assert 'Y' in _network.layers
        assert ('X', 'Y') in _network.connections
        assert 'Y' in _network.monitors
        del _network

        os.remove('net.pt')
Example no. 5
    def test_add_objects(self):
        network = Network(dt=1.0, learning=False)

        inpt = Input(100)
        network.add_layer(inpt, name="X")
        lif = LIFNodes(50)
        network.add_layer(lif, name="Y")

        assert inpt == network.layers["X"]
        assert lif == network.layers["Y"]

        conn = Connection(inpt, lif)
        network.add_connection(conn, source="X", target="Y")

        assert conn == network.connections[("X", "Y")]

        monitor = Monitor(lif, state_vars=["s", "v"])
        network.add_monitor(monitor, "Y")

        assert monitor == network.monitors["Y"]

        network.save("net.pt")
        _network = load("net.pt", learning=True)
        assert _network.learning
        assert "X" in _network.layers
        assert "Y" in _network.layers
        assert ("X", "Y") in _network.connections
        assert "Y" in _network.monitors
        del _network

        os.remove("net.pt")
Example no. 6
def create_hmax(network):
    for size in FILTER_SIZES:
        s1 = Input(shape=(FILTER_TYPES, IMAGE_SIZE, IMAGE_SIZE), traces=True)
        network.add_layer(layer=s1, name=get_s1_name(size))
        # network.add_monitor(Monitor(s1, ["s"]), get_s1_name(size))

        c1 = LIFNodes(shape=(FILTER_TYPES, IMAGE_SIZE // 2, IMAGE_SIZE // 2), thresh=-64, traces=True)
        network.add_layer(layer=c1, name=get_c1_name(size))
        # network.add_monitor(Monitor(c1, ["s", "v"]), get_c1_name(size))

        max_pool = MaxPool2dConnection(s1, c1, kernel_size=2, stride=2, decay=0.2)
        network.add_connection(max_pool, get_s1_name(size), get_c1_name(size))

    for feature in FEATURES:
        for size in FILTER_SIZES:
            s2 = LIFNodes(shape=(1, IMAGE_SIZE // 2, IMAGE_SIZE // 2), thresh=-64, traces=True)
            network.add_layer(layer=s2, name=get_s2_name(size, feature))
            # network.add_monitor(Monitor(s2, ["s", "v"]), get_s2_name(size, feature))

            conv = Conv2dConnection(network.layers[get_c1_name(size)], s2, 15, padding=7,
                                    update_rule=PostPre, wmin=0, wmax=1)

            network.add_monitor(
                Monitor(conv, ["w"]),
                "conv%d%d" % (feature, size)
            )

            network.add_connection(conv, get_c1_name(size), get_s2_name(size, feature))

            c2 = LIFNodes(shape=(1, 1, 1), thresh=-64, traces=True)
            network.add_layer(layer=c2, name=get_c2_name(size, feature))
            # network.add_monitor(Monitor(c2, ["s", "v"]), get_c2_name(size, feature))

            max_pool = MaxPool2dConnection(s2, c2, kernel_size=IMAGE_SIZE // 2, decay=0.0)
            network.add_connection(max_pool, get_s2_name(size, feature), get_c2_name(size, feature))
Example no. 7
    def __init__(self,
                 agent: Agent,
                 environment: GymEnvironment,
                 action_function: Optional[Callable] = None,
                 **kwargs):

        super(PongPipeline, self).__init__(agent.model, **kwargs)

        self.agent = agent
        if not isinstance(self.agent.model, SNN):
            raise TypeError('Only SNN agent can be used in PongPipeline.')

        self.network = self.agent.model
        self.device = self.agent.device
        self.network = self.network.to(self.device)

        self.env = environment
        self.action_function = action_function

        self.accumulated_reward = 0.0
        self.reward_list = []

        self.output = kwargs.get('output', None)
        self.render_interval = kwargs.get('render_interval', None)
        self.reward_delay = kwargs.get('reward_delay', None)
        self.time = kwargs.get('time', int(self.agent.model.dt))
        self.skip_first_frame = kwargs.get('skip_first_frame', True)
        self.replay_buffer = kwargs.get('replay_buffer', None)
        if self.replay_buffer is None:
            warnings.warn(
                'Consider using a replay buffer to handle sparse-reward conditions.'
            )

        if self.reward_delay is not None:
            assert self.reward_delay > 0
            self.rewards = torch.zeros(self.reward_delay)

        # Collect the names of all input layers; there may be more than one.
        self.inputs = [
            name for name, layer in self.network.layers.items()
            if isinstance(layer, AbstractInput)
        ]

        self.action = None

        self.voltage_record = None
        self.threshold_value = None

        self.first = True

        if self.output is not None:
            self.network.add_monitor(
                Monitor(self.network.layers[self.output], ["s"]), self.output)

            self.spike_record = {
                self.output: torch.zeros((self.time, len(self.agent.action)))
            }
Example no. 8
    def create_monitors(self, time):
        monitors = {}
        for layer in set(self.network.layers) - {'X'}:
            monitors[layer] = Monitor(self.network.layers[layer],
                                      state_vars=['s'],
                                      time=time)
            self.network.add_monitor(monitors[layer],
                                     name='%s_monitor' % layer)
        return monitors
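
A short usage sketch for ``create_monitors``, not taken from the source: it assumes an instance of the surrounding class named ``pipeline``, an input layer called "X" as in the method above, and the newer ``run(inputs=...)`` signature; the exact shape returned by ``Monitor.get`` differs between BindsNET versions (compare Examples no. 14 and no. 16).

import torch

sim_time = 250
monitors = pipeline.create_monitors(time=sim_time)
inputs = {"X": torch.bernoulli(0.1 * torch.ones(sim_time, pipeline.network.layers["X"].n))}
pipeline.network.run(inputs=inputs, time=sim_time)

# One recorded spike tensor per non-input layer.
spikes = {layer: mon.get("s") for layer, mon in monitors.items()}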
def LIF(nodes_network):
    LIF = LIFNodes(n=2, traces=True)
    nodes_network.add_layer(layer=LIF, name="LIF")
    nodes_network.add_connection(connection=Connection(source=input_layer,
                                                       target=LIF),
                                 source="Input",
                                 target="LIF")
    LIF_monitor = Monitor(obj=LIF, state_vars=("s", "v"))
    nodes_network.add_monitor(monitor=LIF_monitor, name="LIF monitor")
    return ("LIF", LIF_monitor)
Example no. 10
    def load(self, file_path):
        self.network = load(file_path)
        self.n_iter = 60000

        dt = 1
        intensity = 127.5

        self.train_dataset = MNIST(
            PoissonEncoder(time=self.time_max, dt=dt),
            None,
            "MNIST",
            download=False,
            train=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
                )
            )

        self.spikes = {}
        for layer in set(self.network.layers):
            self.spikes[layer] = Monitor(self.network.layers[layer], state_vars=["s"], time=self.time_max)
            self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)
            #print('GlobalMonitor.state_vars:', self.GlobalMonitor.state_vars)

        self.voltages = {}
        for layer in set(self.network.layers) - {"X"}:
            self.voltages[layer] = Monitor(self.network.layers[layer], state_vars=["v"], time=self.time_max)
            self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)

        weights_XY = self.network.connections[('X', 'Y')].w

        weights_XY = weights_XY.reshape(28, 28, -1)
        # Tile the 625 28x28 receptive fields into a 25x25 grid for display.
        weights_to_display = torch.zeros(0, 28 * 25)
        i = 0
        for j in range(25):
            weights_to_display_row = torch.zeros(28, 0)
            for k in range(25):
                weights_to_display_row = torch.cat((weights_to_display_row, weights_XY[:, :, i]), dim=1)
                i += 1
            weights_to_display = torch.cat((weights_to_display, weights_to_display_row), dim=0)

        self.weights_XY = weights_to_display.numpy()
def CurrentLIF(nodes_network):
    CurrentLIF = CurrentLIFNodes(n=1, traces=True)
    nodes_network.add_layer(layer=CurrentLIF, name="CurrentLIF")
    nodes_network.add_connection(connection=Connection(source=input_layer,
                                                       target=CurrentLIF),
                                 source="Input",
                                 target="CurrentLIF")
    CurrentLIF_monitor = Monitor(obj=CurrentLIF, state_vars=("s", "v"))
    nodes_network.add_monitor(monitor=CurrentLIF_monitor,
                              name="CurrentLIF monitor")
    return ("CurrentLIF", CurrentLIF_monitor)
def AdaptiveLIF(nodes_network):
    AdaptiveLIF = AdaptiveLIFNodes(n=1, traces=True)
    nodes_network.add_layer(layer=AdaptiveLIF, name="AdaptiveLIF")
    nodes_network.add_connection(connection=Connection(source=input_layer,
                                                       target=AdaptiveLIF),
                                 source="Input",
                                 target="AdaptiveLIF")
    AdaptiveLIF_monitor = Monitor(obj=AdaptiveLIF, state_vars=("s", "v"))
    nodes_network.add_monitor(monitor=AdaptiveLIF_monitor,
                              name="AdaptiveLIF monitor")
    return ("AdaptiveLIF", AdaptiveLIF_monitor)
def Izhikevich(nodes_network):
    Izhikevich = IzhikevichNodes(n=1, traces=True)
    nodes_network.add_layer(layer=Izhikevich, name="Izhikevich")
    nodes_network.add_connection(connection=Connection(source=input_layer,
                                                       target=Izhikevich),
                                 source="Input",
                                 target="Izhikevich")
    Izhikevich_monitor = Monitor(obj=Izhikevich, state_vars=("s", "v"))
    nodes_network.add_monitor(monitor=Izhikevich_monitor,
                              name="Izhikevich monitor")
    return ("Izhikevich", Izhikevich_monitor)
Example no. 14
class TestMonitor:
    """
    Testing Monitor object.
    """

    network = Network()

    inpt = Input(75)
    network.add_layer(inpt, name="X")
    _if = IFNodes(25)
    network.add_layer(_if, name="Y")
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source="X", target="Y")

    inpt_mon = Monitor(inpt, state_vars=["s"])
    network.add_monitor(inpt_mon, name="X")
    _if_mon = Monitor(_if, state_vars=["s", "v"])
    network.add_monitor(_if_mon, name="Y")

    network.run(
        inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100
    )

    assert inpt_mon.get("s").size() == torch.Size([100, 1, inpt.n])
    assert _if_mon.get("s").size() == torch.Size([100, 1, _if.n])
    assert _if_mon.get("v").size() == torch.Size([100, 1, _if.n])

    del network.monitors["X"], network.monitors["Y"]

    inpt_mon = Monitor(inpt, state_vars=["s"], time=500)
    network.add_monitor(inpt_mon, name="X")
    _if_mon = Monitor(_if, state_vars=["s", "v"], time=500)
    network.add_monitor(_if_mon, name="Y")

    network.run(
        inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500
    )

    assert inpt_mon.get("s").size() == torch.Size([500, 1, inpt.n])
    assert _if_mon.get("s").size() == torch.Size([500, 1, _if.n])
    assert _if_mon.get("v").size() == torch.Size([500, 1, _if.n])
Example no. 15
def main(n_input=1, n_output=10, time=1000):
    # Network building.
    network = Network(dt=1.0)
    input_layer = RealInput(n=n_input)
    output_layer = LIFNodes(n=n_output)
    connection = Connection(source=input_layer, target=output_layer)
    monitor = Monitor(obj=output_layer, state_vars=('v', ), time=time)

    # Adding network components.
    network.add_layer(input_layer, name='X')
    network.add_layer(output_layer, name='Y')
    network.add_connection(connection, source='X', target='Y')
    network.add_monitor(monitor, name='Y_monitor')

    # Creating real-valued inputs and running simulation.
    inpts = {'X': torch.ones(time, n_input)}
    network.run(inpts=inpts, time=time)

    # Plot voltage activity.
    plt.plot(monitor.get('v').numpy().T)
    plt.show()
Example no. 16
class TestMonitor:
    """
    Testing Monitor object.
    """
    network = Network()

    inpt = Input(75)
    network.add_layer(inpt, name='X')
    _if = IFNodes(25)
    network.add_layer(_if, name='Y')
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source='X', target='Y')

    inpt_mon = Monitor(inpt, state_vars=['s'])
    network.add_monitor(inpt_mon, name='X')
    _if_mon = Monitor(_if, state_vars=['s', 'v'])
    network.add_monitor(_if_mon, name='Y')

    network.run(inpts={'X': torch.bernoulli(torch.rand(100, inpt.n))},
                time=100)

    assert inpt_mon.get('s').size() == torch.Size([inpt.n, 100])
    assert _if_mon.get('s').size() == torch.Size([_if.n, 100])
    assert _if_mon.get('v').size() == torch.Size([_if.n, 100])

    del network.monitors['X'], network.monitors['Y']

    inpt_mon = Monitor(inpt, state_vars=['s'], time=500)
    network.add_monitor(inpt_mon, name='X')
    _if_mon = Monitor(_if, state_vars=['s', 'v'], time=500)
    network.add_monitor(_if_mon, name='Y')

    network.run(inpts={'X': torch.bernoulli(torch.rand(500, inpt.n))},
                time=500)

    assert inpt_mon.get('s').size() == torch.Size([inpt.n, 500])
    assert _if_mon.get('s').size() == torch.Size([_if.n, 500])
    assert _if_mon.get('v').size() == torch.Size([_if.n, 500])
Example no. 17
    def _init_network_monitor(self, network, cfg):
        exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"],
                                      time=cfg['time'])
        inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"],
                                      time=cfg['time'])
        network.add_monitor(exc_voltage_monitor, name="exc_voltage")
        network.add_monitor(inh_voltage_monitor, name="inh_voltage")

        spikes = {}
        for layer in set(network.layers):
            spikes[layer] = Monitor(network.layers[layer],
                                    state_vars=["s"],
                                    time=cfg['time'])
            network.add_monitor(spikes[layer], name="%s_spikes" % layer)

        voltages = {}
        for layer in set(network.layers) - {"X"}:
            voltages[layer] = Monitor(network.layers[layer],
                                      state_vars=["v"],
                                      time=cfg['time'])
            network.add_monitor(voltages[layer], name="%s_voltages" % layer)

        return exc_voltage_monitor, inh_voltage_monitor, spikes, voltages
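
A hedged sketch of how the returned monitors might be consumed from a training method of the same class; ``encoded_inputs`` stands in for an already-encoded input batch, and the newer ``run``/``reset_state_variables`` API used in other examples is assumed.

        exc_v, inh_v, spikes, voltages = self._init_network_monitor(network, cfg)

        network.run(inputs=encoded_inputs, time=cfg['time'])
        exc_voltages = exc_v.get("v")  # recorded "v" for layer "Ae", one entry per timestep
        layer_spikes = {l: m.get("s") for l, m in spikes.items()}
        network.reset_state_variables()  # clear network state and monitor buffers before the next sample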
Example no. 18
    def test_add_objects(self):
        network = Network(dt=1.0)

        inpt = Input(100)
        network.add_layer(inpt, name='X')
        lif = LIFNodes(50)
        network.add_layer(lif, name='Y')

        assert inpt == network.layers['X']
        assert lif == network.layers['Y']

        conn = Connection(inpt, lif)
        network.add_connection(conn, source='X', target='Y')

        assert conn == network.connections[('X', 'Y')]

        monitor = Monitor(lif, state_vars=['s', 'v'])
        network.add_monitor(monitor, 'Y')

        assert monitor == network.monitors['Y']
Example no. 19
    def __init__(self, parameters: BenchmarkParameters):
        super(BindsNetModule, self).__init__()
        network = Network(batch_size=parameters.batch_size, dt=parameters.dt)
        lif_nodes = LIFNodes(n=parameters.features)
        monitor = Monitor(obj=lif_nodes,
                          state_vars=("s",),
                          time=parameters.sequence_length)
        network.add_layer(Input(n=parameters.features), name="Input")
        network.add_layer(lif_nodes, name="Neurons")
        network.add_connection(
            Connection(source=network.layers["Input"],
                       target=network.layers["Neurons"]),
            source="Input",
            target="Neurons",
        )
        network.add_monitor(monitor, "Monitor")
        network.to(parameters.device)

        self.parameters = parameters
        self.network = network
        self.monitor = monitor
Example no. 20
    def __init__(self,
                 encoder,
                 dt: float = 1.0,
                 lag: int = 10,
                 n_neurons: int = 100,
                 time: int = 100,
                 learning: bool = False):
        super().__init__(dt=dt)
        self.learning = learning
        self.n_neurons = n_neurons
        self.lag = lag
        self.encoder = encoder
        self.time = time

        for i in range(lag):
            self.add_layer(RealInput(n=encoder.e_size, traces=True),
                           name=f'input_{i+1}')
            self.add_layer(LIFNodes(n=self.n_neurons, traces=True),
                           name=f'column_{i+1}')
            self.add_monitor(Monitor(self.layers[f'column_{i+1}'], ['s'],
                                     time=self.time),
                             name=f'monitor_{i+1}')
            w = 0.3 * torch.rand(self.encoder.e_size, self.n_neurons)
            self.add_connection(Connection(source=self.layers[f'input_{i+1}'],
                                           target=self.layers[f'column_{i+1}'],
                                           w=w),
                                source=f'input_{i+1}',
                                target=f'column_{i+1}')

        for i in range(lag):
            for j in range(lag):
                w = torch.zeros(self.n_neurons, self.n_neurons)
                self.add_connection(Connection(
                    source=self.layers[f'column_{i+1}'],
                    target=self.layers[f'column_{j+1}'],
                    w=w,
                    update_rule=Hebbian,
                    nu=args.nu),
                                    source=f'column_{i+1}',
                                    target=f'column_{j+1}')
Example no. 21
def add_decision_layers(network):
    output = LIFNodes(n=len(SUBJECTS), thresh=-60, traces=True)
    network.add_layer(output, "OUT")
    network.add_monitor(Monitor(output, ["s", "v"]), "OUT")

    for feature in FEATURES:
        for size in FILTER_SIZES:
            connection = Connection(
                source=network.layers[get_c2_name(size, feature)],
                target=output,
                w=0.05 + 0.1 * torch.randn(
                    network.layers[get_c2_name(size, feature)].n, output.n),
                update_rule=PostPre)
            network.add_connection(connection, get_c2_name(size, feature),
                                   "OUT")

    rec_connection = Connection(
        source=output,
        target=output,
        w=0.05 * (torch.eye(output.n) - 1),
        decay=0.0,
    )
    network.add_connection(rec_connection, "OUT", "OUT")
Example no. 22
def convert(ann, dataset):
    snn = ann_to_snn(ann, input_shape=(1, 10 * 25), data=dataset)
    if not args.no_negative:
        nr = SubtractiveResetIFNodes(50, [50],
                                     False,
                                     refrac=0,
                                     reset=0,
                                     thresh=1)
        nr.dt = 1.0
        nr.network = snn
        snn.layers['2'] = nr
        nw = torch.zeros((snn.connections[('1', '2')].w.shape[0],
                          snn.connections[('1', '2')].w.shape[1] * 2))
        nw[:, :25] = snn.connections[('1', '2')].w
        nw[:, 25:] = -snn.connections[('1', '2')].w
        b = torch.zeros(50)
        b[:25] = snn.connections[('1', '2')].b
        b[25:] = -snn.connections[('1', '2')].b
        snn.connections[('1', '2')].w = nw
        snn.connections[('1', '2')].b = b
        snn.connections[('1', '2')].target = snn.layers['2']
    snn.add_monitor(monitor=Monitor(obj=snn.layers['2'], state_vars=['s']),
                    name='output_monitor')
    return snn
Example no. 23
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build Diehl & Cook 2015 network.
network = DiehlAndCook2015(n_inpt=784,
                           n_neurons=n_neurons,
                           exc=exc,
                           inh=inh,
                           dt=dt,
                           norm=78.4,
                           nu=[0.0, 1e-2],
                           inpt_shape=(1, 28, 28))

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Load MNIST data.
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * intensity)]),
)
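
Once the network has been run on an encoded sample (Example no. 24 below shows a full training loop), the recorded voltages can be read back from these monitors; a short sketch, assuming the newer ``Monitor.get``/``reset_state_variables`` API used elsewhere in these examples.

exc_voltages = exc_voltage_monitor.get("v")  # every recorded "v" value for layer "Ae"
inh_voltages = inh_voltage_monitor.get("v")  # every recorded "v" value for layer "Ai"
network.reset_state_variables()              # clear network state and monitor buffers between samples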
Example no. 24
def main(args):
    if args.update_steps is None:
        # update_steps: number of batches to process between accuracy updates,
        # label re-assignment, and plot/log refreshes.
        args.update_steps = max(250 // args.batch_size, 1)

    # update_interval: number of images processed between those updates.
    update_interval = args.update_steps * args.batch_size

    # Set up GPU use.
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        # Seed all GPUs for reproducibility.
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # Select the reduction applied across the batch dimension for weight updates.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # Input dimensions are 28x28 = 784.
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose(  #Composes several transforms together
            [
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x * args.intensity)
            ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # One label per neuron; -1 means unassigned.
    proportions = torch.zeros(args.n_neurons, n_classes)  # Per-neuron label proportions.
    rates = torch.zeros(args.n_neurons, n_classes)  # Per-neuron, per-label firing rates.

    # Set up monitors that record the spike state variable "s" of every layer;
    # passing time pre-allocates memory for args.time timesteps of recording.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(
            network.layers[layer], state_vars=["s"], time=args.time
        )
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)
    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        # Remove any previous log directory so TensorBoard starts from a clean state.
        shutil.rmtree(args.log_dir)

    # Summary writer for logging metrics and images to TensorBoard;
    # pending events are flushed to disk every 60 seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)
    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate over and batch the training data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,  # Reshuffle the data at every epoch.
            num_workers=args.n_workers,
            pin_memory=args.gpu,  # Copy tensors into CUDA pinned memory when using the GPU.
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                # Record the all-activity accuracy for this update window.
                value = torch.mean(
                    (label_tensor.long() == all_activity_pred).float())
                value = value.item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Accumulate the labels seen in this update window.
            labels.extend(batch["label"].tolist())

            # Prep the next input batch and move it to the GPU if one is in use.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input and time the simulation.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            spike_record[(step * args.batch_size) %
                         update_interval:(step * args.batch_size %
                                          update_interval) + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)
            # if(step==1):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data = asarray(an_array)
            #     savetxt('data.csv',data)
            #     print("Beginning weights saved")
            # if(step==3749):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data2 = asarray(an_array)
            #     savetxt('data2.csv',data2)
            #     print("Ending weights saved")
            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                # print("Weights:",input_exc_weights)
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
        print(end_accuracy())  #Vennila
Example no. 25
# Competitive (inhibitory) recurrent weights between distinct filters at the
# same spatial location of the convolutional layer.
w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
for fltr1 in range(n_filters):
    for fltr2 in range(n_filters):
        if fltr1 != fltr2:
            for i in range(conv_size):
                for j in range(conv_size):
                    w[fltr1, i, j, fltr2, i, j] = -100.0

w = w.view(n_filters * conv_size * conv_size,
           n_filters * conv_size * conv_size)
recurrent_conn = Connection(conv_layer, conv_layer, w=w)

network.add_layer(input_layer, name="X")
network.add_layer(conv_layer, name="Y")
network.add_connection(conv_conn, source="X", target="Y")
network.add_connection(recurrent_conn, source="Y", target="Y")

# Voltage recording for excitatory and inhibitory layers.
voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
network.add_monitor(voltage_monitor, name="output_voltage")

if gpu:
    network.to("cuda")

# Load MNIST data.
train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "../../data/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * intensity)]),
)

def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=(16, ),
         stride=(4, ),
         n_filters=25,
         padding=0,
         inhib=100,
         time=25,
         lr=1e-3,
         lr_decay=0.99,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time,
        lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding,
            inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    if kernel_size == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)
        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0,
                                       traces=True,
                                       theta_plus=0.05 * (kernel_size[0] / 20),
                                       refrac=0)
        conv_layer2 = LIFNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size),
                               refrac=0)
        conv_conn = Conv2dConnection(input_layer,
                                     conv_layer,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     update_rule=WeightDependentPostPre,
                                     norm=0.05 * total_kernel_size,
                                     nu=[0, lr],
                                     wmin=0,
                                     wmax=0.25)
        conv_conn2 = Conv2dConnection(input_layer,
                                      conv_layer2,
                                      w=conv_conn.w,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      update_rule=None,
                                      wmax=0.25)

        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    w[f, :, :f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        logreg_coef, logreg_intercept = torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i %
                                        len(labels)]
                current_record = full_spike_record[i % len(labels) -
                                                   update_interval:i %
                                                   len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          full_spike_record=current_record,
                                          logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i],
                                          logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(
            time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w

            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size,
                                               time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size,
                                                 time)
            }

            inpt_axes, inpt_ims = plot_input(image.view(20, 20),
                                             _input,
                                             label=labels[i % len(labels)],
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(
                w, im=weights_im, wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i %
                                len(labels)]
        current_record = full_spike_record[i % len(labels) -
                                           update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  full_spike_record=current_record,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                    'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_logreg',
                    'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
Example no. 27
network.add_connection(FF1b, source="I_b", target="TNN_1b")
network.add_connection(FF2a, source="TNN_1a", target="rTNN_1")
network.add_connection(FF2b, source="TNN_1b", target="rTNN_1")
# (Recurrences)
network.add_connection(rTNN_to_buf1, source="rTNN_1", target="BUF_1")
# network.add_connection(buf1_to_buf2, source="BUF_1", target="BUF_2")
network.add_connection(buf1_to_rTNN, source="BUF_1", target="rTNN_1")
# network.add_connection(buf2_to_rTNN, source="BUF_2", target="rTNN_1")


# End of network creation

# Monitors:
spikes = {}
for l in network.layers:
	spikes[l] = Monitor(network.layers[l], ["s"], time=num_timesteps)
	network.add_monitor(spikes[l], name="%s_spikes" % l)


# Data and initial encoding:
dataset = MNIST(
	RampNoLeakTNNEncoder(time=num_timesteps, dt=1),
	None,
	root=os.path.join("..", "..", "data", "MNIST"),
	download=True,
	transform=transforms.Compose(
		[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
	),
)

Example no. 28
network.add_connection(connection=Parallelfiber,
                       source="GR_Joint_layer",
                       target="PK")
network.add_connection(connection=Parallelfiber_Anti,
                       source="GR_Joint_layer",
                       target="PK_Anti")
network.add_connection(connection=Climbingfiber, source="IO", target="PK")
network.add_connection(connection=Climbingfiber_Anti,
                       source="IO_Anti",
                       target="PK_Anti")
network.add_connection(connection=PK_DCN, source="PK", target="DCN")
network.add_connection(connection=PK_DCN_Anti,
                       source="PK_Anti",
                       target="DCN_Anti")

GR_monitor = Monitor(obj=GR_Joint_layer, state_vars=("s",), time=time)
PK_monitor = Monitor(obj=PK, state_vars=("s", "v"), time=time)
PK_Anti_monitor = Monitor(obj=PK_Anti, state_vars=("s", "v"), time=time)

IO_monitor = Monitor(obj=IO, state_vars=("s",), time=time)
DCN_monitor = Monitor(
    obj=DCN,
    state_vars=("s", "v"),
    time=time,
)

DCN_Anti_monitor = Monitor(obj=DCN_Anti, state_vars=("s", "v"), time=time)
network.add_monitor(monitor=GR_monitor, name="GR")
network.add_monitor(monitor=PK_monitor, name="PK")
network.add_monitor(monitor=PK_Anti_monitor, name="PK_Anti")
network.add_monitor(monitor=IO_monitor, name="IO")
Example no. 29
network = Network(dt=1.0)
inpt = Input(784, shape=(28, 28))
network.add_layer(inpt, name="I")
output = LIFNodes(625, thresh=-52 + torch.randn(625))
network.add_layer(output, name="O")
C1 = Connection(source=inpt, target=output, w=torch.randn(inpt.n, output.n))
C2 = Connection(source=output,
                target=output,
                w=0.5 * torch.randn(output.n, output.n))

network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")

spikes = {}
for l in network.layers:
    spikes[l] = Monitor(network.layers[l], ["s"], time=250)
    network.add_monitor(spikes[l], name="%s_spikes" % l)

voltages = {"O": Monitor(network.layers["O"], ["v"], time=250)}
network.add_monitor(voltages["O"], name="O_voltages")

# Get MNIST training images and labels.
images, labels = MNIST(path="../../data/MNIST", download=True).get_train()
images *= 0.25

# Create lazily iterating Poisson-distributed data loader.
loader = zip(poisson_loader(images, time=250), iter(labels))

inpt_axes = None
inpt_ims = None
spike_axes = None
def main(seed=0,
         n_train=60000,
         n_test=10000,
         inhib=250,
         kernel_size=(16, ),
         stride=(2, ),
         time=50,
         n_filters=25,
         crop=0,
         lr=1e-2,
         lr_decay=0.99,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         norm=0.2,
         progress_interval=10,
         update_interval=250,
         train=True,
         relabel=False,
         plot=False,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0 or relabel, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
        inhib, time, dt, theta_plus, theta_decay, norm, progress_interval,
        update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
            n_test, inhib, time, dt, theta_plus, theta_decay, norm,
            progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length**2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = LocallyConnectedNetwork(
            n_inpt=n_inpt,
            input_shape=[side_length, side_length],
            kernel_size=kernel_size,
            stride=stride,
            n_filters=n_filters,
            inh=inhib,
            dt=dt,
            nu=[.1 * lr, lr],
            theta_plus=theta_plus,
            theta_decay=theta_decay,
            wmin=0,
            wmax=1.0,
            norm=norm)
        network.layers['Y'].thresh = 1
        network.layers['Y'].reset = 0
        network.layers['Y'].rest = 0

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    if crop != 0:
        images = images[:, crop:-crop, crop:-crop]

    # Record spikes during the simulation.
    if not train:
        update_interval = n_examples

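    # Per-example spike trains of the output layer for the current update window.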
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)].contiguous().view(-1)
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

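        # If the output layer is nearly silent, brighten the image and re-run (up to 3 retries).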
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w,
                n_filters,
                kernel_size,
                conv_size,
                locations,
                side_length,
                im=weights_im,
                wmin=0,
                wmax=1)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    if not train and relabel:
        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(
            spike_record, current_labels, n_classes, rates)

        # Compute ngram scores.
        ngram_scores = update_ngram_scores(spike_record, current_labels,
                                           n_classes, 2, ngram_scores)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,n_train,inhib,time,lr,lr_decay,timestep,theta_plus,'
                    'theta_decay,norm,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,n_train,n_test,inhib,time,lr,lr_decay,timestep,'
                    'theta_plus,theta_decay,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
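If this function is meant to be driven from the command line, it would typically be wired to argparse; a minimal sketch exposing only a few of the parameters above (the flag names are assumptions):

# Minimal sketch (assumed usage): command-line entry point for main().
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_train', type=int, default=60000)
    parser.add_argument('--n_test', type=int, default=10000)
    parser.add_argument('--time', type=int, default=50)
    parser.add_argument('--test', dest='train', action='store_false')
    parser.add_argument('--plot', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    parser.set_defaults(train=True)
    args = parser.parse_args()

    main(seed=args.seed, n_train=args.n_train, n_test=args.n_test,
         time=args.time, train=args.train, plot=args.plot, gpu=args.gpu)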