Example #1
0
    def __init__(self, spec: specs.UnitSpec = None) -> None:
        """Create a single unit with every piece of dynamic state at zero.

        Args:
            spec: Parameters for this unit. When omitted (``None``), a
                default ``specs.UnitSpec()`` is constructed and used.
        """
        self.spec = specs.UnitSpec() if spec is None else spec

        # --- dynamic state (all start at zero) ---
        self.net_raw = 0.0  # net input (excitation), no time integration
        self.net = 0.0  # net input (excitation), time-integrated
        self.gc_i = 0.0  # total (feedback + feedforward) inhibition
        self.act = 0.0  # activation
        self.i_net = 0.0  # net current
        self.i_net_r = 0.0  # rate-coded net current (driven by v_m_eq)
        self.v_m = 0.0  # membrane potential
        self.v_m_eq = 0.0  # equilibrium membrane potential; not reset on spike
        self.adapt = 0.0  # adaptation current
        self.spike = 0  # spiking flag, 0 or 1 (kept as an int)
Example #2
0
def test_kwta_suppresses_all_but_k_units() -> None:
    """kWTA inhibition in ``lr2`` should leave exactly ``k`` units above 0.8.

    ``lr1`` (one unit) is forced to 1 and feeds ``lr2`` (three units) through
    two projections. ``proj2`` carries ``post_mask=[0, 1, 1]``, so the units
    are presumably driven unequally (verify mask semantics). After running
    the network, the activations logged on the final cycle are inspected.
    """
    num_cycles = 100
    k = 2

    n = net.Net()

    n.new_layer(name="lr1", size=1)

    lr2_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=k,
        log_on_cycle=("unit_act", ),
        # NOTE(review): presumably zeroes adaptation and spike effects so
        # activations can settle; confirm against UnitSpec.
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    n.new_layer(name="lr2", size=3, spec=lr2_spec)

    pr1_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    n.new_projn("proj1", "lr1", "lr2", pr1_spec)

    pr2_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    n.new_projn("proj2", "lr1", "lr2", pr2_spec)

    n.force_layer("lr1", [1])
    for _ in range(num_cycles):
        n.cycle()

    logs = n.logs("cycle", "lr2")
    # Cycle times are presumably 0-based, so the last cycle is num_cycles - 1
    # (the original test hard-coded 99 for 100 cycles).
    acts = logs[logs.time == num_cycles - 1]["act"]
    assert (acts > 0.8).sum() == k
Example #3
0
    def __init__(self, size: int, spec: specs.UnitSpec = None) -> None:
        """Create a group of ``size`` units whose state vectors start at zero.

        Args:
            size: Number of units in the group; must be greater than zero.
            spec: Unit parameters for the group. When omitted (``None``), a
                default ``specs.UnitSpec()`` is constructed and used.

        Raises:
            ValueError: If ``size`` is not greater than zero.
        """
        if size <= 0:
            raise ValueError("size must be greater than zero.")
        self.size = size

        if spec is None:
            self.spec = specs.UnitSpec()
        else:
            self.spec = spec

        # When adding any attribute to this class, update
        # layer.LayerSpec._valid_log_on_cycle

        # Each state attribute is a 1-D tensor with one entry per unit.
        # torch.zeros allocates and zero-fills in one step, whereas
        # torch.Tensor(n) returns uninitialized memory that had to be
        # cleared with .zero_().

        # Net input (excitation) without time integration
        self.net_raw = torch.zeros(self.size)
        # Net input (excitation) with time integration
        self.net = torch.zeros(self.size)
        # Total (feedback + feedforward) inhibition
        self.gc_i = torch.zeros(self.size)
        # Activation
        self.act = torch.zeros(self.size)
        # Net current
        self.i_net = torch.zeros(self.size)
        # Net current, rate-coded (driven by v_m_eq)
        self.i_net_r = torch.zeros(self.size)
        # Membrane potential
        self.v_m = torch.zeros(self.size)
        # Equilibrium membrane potential (does not reset on spike)
        self.v_m_eq = torch.zeros(self.size)
        # Adaptation current
        self.adapt = torch.zeros(self.size)
        # Are we spiking? (0 or 1)
        # In the future, this could be a ByteTensor
        self.spike = torch.zeros(self.size)
Example #4
0
def test_unit_spec_validates_integ(f) -> None:
    """Validating a UnitSpec built with ``integ=f`` must raise ValidationError.

    ``f`` is supplied from outside this view (decorator or fixture); every
    value it takes is expected to be rejected.
    """
    with pytest.raises(sp.ValidationError):
        candidate = sp.UnitSpec(integ=f)
        candidate.validate()
Example #5
0
def test_it_should_validate_the_unit_spec() -> None:
    """LayerSpec validation must surface an invalid nested UnitSpec.

    A LayerSpec whose ``unit_spec`` has ``vm_dt=-1`` is expected to raise
    ValidationError when validated.
    """
    with pytest.raises(sp.ValidationError):
        bad_unit_spec = sp.UnitSpec(vm_dt=-1)
        layer_spec = sp.LayerSpec(unit_spec=bad_unit_spec)
        layer_spec.validate()
Example #6
0
def test_it_should_validate_l_up_inc(f) -> None:
    """Every supplied ``l_up_inc`` value ``f`` must fail UnitSpec validation."""
    with pytest.raises(sp.ValidationError):
        spec_under_test = sp.UnitSpec(l_up_inc=f)
        spec_under_test.validate()
Example #7
0
def test_it_should_validate_m_dt(f) -> None:
    """``UnitSpec.validate`` must reject each supplied ``m_dt`` value ``f``."""
    expected_error = sp.ValidationError
    with pytest.raises(expected_error):
        sp.UnitSpec(m_dt=f).validate()
Example #8
0
def test_it_should_validate_act_gain_positive(f) -> None:
    """``act_gain`` values supplied as ``f`` must fail validation.

    Judging by the test name, ``f`` presumably ranges over non-positive
    values; confirm against the decorator/fixture that provides it.
    """
    with pytest.raises(sp.ValidationError):
        gain_spec = sp.UnitSpec(act_gain=f)
        gain_spec.validate()
Example #9
0
def test_it_should_validate_syn_tr(f) -> None:
    """A UnitSpec configured with ``syn_tr=f`` must not pass validation."""
    with pytest.raises(sp.ValidationError):
        spec = sp.UnitSpec(syn_tr=f)
        spec.validate()
Example #10
0
def test_it_should_validate_v_m_r_is_less_than_spk_thr() -> None:
    """A UnitSpec with ``v_m_r`` above ``spk_thr`` must fail validation."""
    with pytest.raises(sp.ValidationError):
        inverted = sp.UnitSpec(v_m_r=1, spk_thr=0.5)
        inverted.validate()
Example #11
0
def test_it_should_validate_adapt_dt(f) -> None:
    """``UnitSpec.validate`` must reject each supplied ``adapt_dt`` value ``f``."""
    with pytest.raises(sp.ValidationError):
        sp.UnitSpec(adapt_dt=f).validate()
Example #12
0
def test_it_should_validate_v_m_r_for_insane_floats(f) -> None:
    """``UnitSpec.validate`` must reject each supplied ``v_m_r`` value ``f``.

    Per the test name, ``f`` presumably covers non-finite floats
    (e.g. NaN/inf); confirm against the decorator/fixture that provides it.
    """
    with pytest.raises(sp.ValidationError):
        sp.UnitSpec(v_m_r=f).validate()
Example #13
0
def test_unit_init_uses_the_spec_you_pass_it() -> None:
    """Unit must keep the very spec object it was constructed with (identity)."""
    given_spec = sp.UnitSpec()
    assert un.Unit(spec=given_spec).spec is given_spec
Example #14
0
def test_unitgroup_sets_the_spec_you_provide() -> None:
    """UnitGroup must store the caller's spec object itself (identity check)."""
    provided = sp.UnitSpec()
    group = un.UnitGroup(size=3, spec=provided)
    assert group.spec is provided
Example #15
0
def test_unitgroup_uses_the_default_spec_if_none_is_provided() -> None:
    """Without a ``spec`` argument, the group's spec equals a fresh UnitSpec()."""
    default_group = un.UnitGroup(size=3)
    fresh_spec = sp.UnitSpec()
    assert default_group.spec == fresh_spec
Example #16
0
def test_it_should_validate_e_rev_e_for_insane_floats(f) -> None:
    """``e_rev_e`` values supplied as ``f`` must be rejected by validation.

    Per the test name, ``f`` presumably covers non-finite floats; confirm
    against the decorator/fixture that provides it.
    """
    with pytest.raises(sp.ValidationError):
        reversal_spec = sp.UnitSpec(e_rev_e=f)
        reversal_spec.validate()
Example #17
0
def test_unit_init_can_make_a_defaut_spec_for_you() -> None:
    """Constructing a Unit without a spec gives it one equal to UnitSpec()."""
    # NOTE(review): "defaut" in the name is a typo; it is kept so the pytest
    # test id stays stable.
    unit = un.Unit()
    assert unit.spec == sp.UnitSpec()
Example #18
0
def test_it_should_validate_spike_gain_for_insane_floats(f) -> None:
    """``UnitSpec.validate`` must reject each supplied ``spike_gain`` value ``f``.

    Per the test name, ``f`` presumably covers non-finite floats; confirm
    against the decorator/fixture that provides it.
    """
    error_type = sp.ValidationError
    with pytest.raises(error_type):
        gain_spec = sp.UnitSpec(spike_gain=f)
        gain_spec.validate()