def __init__(self, spec: specs.UnitSpec = None) -> None:
    """Set up a single unit with all of its state variables at zero.

    Args:
        spec: Unit specification to attach to this unit. When omitted
            (``None``), a default ``specs.UnitSpec()`` is created.
    """
    self.spec = specs.UnitSpec() if spec is None else spec

    # --- Excitatory input ---
    # Raw net input (excitation), before time integration
    self.net_raw = 0.0
    # Net input (excitation), after time integration
    self.net = 0.0

    # --- Inhibition ---
    # Total inhibition (feedback + feedforward)
    self.gc_i = 0.0

    # --- Output ---
    # Activation
    self.act = 0.0

    # --- Currents ---
    # Net current
    self.i_net = 0.0
    # Rate-coded net current (driven by v_m_eq)
    self.i_net_r = 0.0

    # --- Membrane potential ---
    # Membrane potential
    self.v_m = 0.0
    # Equilibrium membrane potential (does not reset on spike)
    self.v_m_eq = 0.0

    # --- Adaptation and spiking ---
    # Adaptation current
    self.adapt = 0.0
    # Spiking flag (0 or 1)
    self.spike = 0
def test_kwta_suppresses_all_but_k_units() -> None:
    """With kWTA (k=2) on a 3-unit layer, exactly two units stay active.

    A forced 1-unit layer drives a 3-unit kWTA layer through two
    projections for 100 cycles; on the last logged cycle exactly two
    units must have an activation above 0.8.
    """
    n_cycles = 100

    network = net.Net()
    network.new_layer(name="lr1", size=1)

    # Receiving layer: kWTA inhibition with k=2, logging unit activations
    # each cycle. adapt_dt and spike_gain are set to 0 on its units.
    kwta_layer_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    network.new_layer(name="lr2", size=3, spec=kwta_layer_spec)

    # First projection: lr1 -> lr2 with a Scalar(0.3) distribution.
    base_projn_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    network.new_projn("proj1", "lr1", "lr2", base_projn_spec)

    # Second projection: Scalar(0.5) with post_mask [0, 1, 1]
    # (presumably limiting it to the last two lr2 units -- confirm
    # against ProjnSpec).
    masked_projn_spec = sp.ProjnSpec(dist=rn.Scalar(0.5),
                                     post_mask=[0, 1, 1])
    network.new_projn("proj2", "lr1", "lr2", masked_projn_spec)

    network.force_layer("lr1", [1])

    for _ in range(n_cycles):
        network.cycle()

    cycle_logs = network.logs("cycle", "lr2")
    is_last_cycle = cycle_logs.time == n_cycles - 1
    final_acts = cycle_logs[is_last_cycle]["act"]
    strongly_active = final_acts > 0.8
    assert strongly_active.sum() == 2
def __init__(self, size: int, spec: specs.UnitSpec = None) -> None:
    """Set up a group of units whose state tensors all start at zero.

    Args:
        size: Number of units in the group; must be greater than zero.
        spec: Unit specification to attach to the group. When omitted
            (``None``), a default ``specs.UnitSpec()`` is created.

    Raises:
        ValueError: If ``size`` is less than or equal to zero.
    """
    if size <= 0:
        raise ValueError("size must be greater than zero.")
    self.size = size

    self.spec = specs.UnitSpec() if spec is None else spec

    def zeroed() -> torch.Tensor:
        """Return a new zero-filled tensor sized by ``self.size``."""
        return torch.Tensor(self.size).zero_()

    # NOTE: when adding any attribute to this class, also update
    # layer.LayerSpec._valid_log_on_cycle

    # --- Excitatory input ---
    # Raw net input (excitation), before time integration
    self.net_raw = zeroed()
    # Net input (excitation), after time integration
    self.net = zeroed()

    # --- Inhibition ---
    # Total inhibition (feedback + feedforward)
    self.gc_i = zeroed()

    # --- Output ---
    # Activation
    self.act = zeroed()

    # --- Currents ---
    # Net current
    self.i_net = zeroed()
    # Rate-coded net current (driven by v_m_eq)
    self.i_net_r = zeroed()

    # --- Membrane potential ---
    # Membrane potential
    self.v_m = zeroed()
    # Equilibrium membrane potential (does not reset on spike)
    self.v_m_eq = zeroed()

    # --- Adaptation and spiking ---
    # Adaptation current
    self.adapt = zeroed()
    # Spiking flag (0 or 1); in the future, this could be a ByteTensor
    self.spike = zeroed()
def test_unit_spec_validates_integ(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``integ`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; presumably it is an invalid value -- confirm at the
    definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(integ=f)
        bad_spec.validate()
def test_it_should_validate_the_unit_spec() -> None:
    """``LayerSpec.validate()`` must fail when its unit spec has vm_dt=-1."""
    with pytest.raises(sp.ValidationError):
        bad_unit_spec = sp.UnitSpec(vm_dt=-1)
        layer_spec = sp.LayerSpec(unit_spec=bad_unit_spec)
        layer_spec.validate()
def test_it_should_validate_l_up_inc(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``l_up_inc`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; presumably it is an invalid value -- confirm at the
    definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(l_up_inc=f)
        bad_spec.validate()
def test_it_should_validate_m_dt(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``m_dt`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; presumably it is an invalid value -- confirm at the
    definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(m_dt=f)
        bad_spec.validate()
def test_it_should_validate_act_gain_positive(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``act_gain`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; judging by the test name it is presumably a non-positive
    value -- confirm at the definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(act_gain=f)
        bad_spec.validate()
def test_it_should_validate_syn_tr(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``syn_tr`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; presumably it is an invalid value -- confirm at the
    definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(syn_tr=f)
        bad_spec.validate()
def test_it_should_validate_v_m_r_is_less_than_spk_thr() -> None:
    """A spec with v_m_r (1) above spk_thr (0.5) must fail validation."""
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(v_m_r=1, spk_thr=0.5)
        bad_spec.validate()
def test_it_should_validate_adapt_dt(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``adapt_dt`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; presumably it is an invalid value -- confirm at the
    definition site.
    """
    # The spec is built inside the context manager so that a
    # ValidationError raised at construction time is accepted as well.
    with pytest.raises(sp.ValidationError):
        sp.UnitSpec(adapt_dt=f).validate()
def test_it_should_validate_v_m_r_for_insane_floats(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``v_m_r`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; judging by the test name it is presumably a non-finite
    float (NaN or infinity) -- confirm at the definition site.
    """
    # The spec is built inside the context manager so that a
    # ValidationError raised at construction time is accepted as well.
    with pytest.raises(sp.ValidationError):
        sp.UnitSpec(v_m_r=f).validate()
def test_unit_init_uses_the_spec_you_pass_it() -> None:
    """A ``Unit`` must keep the exact spec object it was constructed with."""
    given_spec = sp.UnitSpec()
    built_unit = un.Unit(spec=given_spec)
    # Identity, not equality: the unit must not copy the spec.
    assert built_unit.spec is given_spec
def test_unitgroup_sets_the_spec_you_provide() -> None:
    """A ``UnitGroup`` must keep the exact spec object it was given."""
    given_spec = sp.UnitSpec()
    group = un.UnitGroup(size=3, spec=given_spec)
    # Identity, not equality: the group must not copy the spec.
    assert group.spec is given_spec
def test_unitgroup_uses_the_default_spec_if_none_is_provided() -> None:
    """A ``UnitGroup`` built without a spec must equal-compare to the default."""
    default_group = un.UnitGroup(size=3)
    expected_spec = sp.UnitSpec()
    assert default_group.spec == expected_spec
def test_it_should_validate_e_rev_e_for_insane_floats(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``e_rev_e`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; judging by the test name it is presumably a non-finite
    float (NaN or infinity) -- confirm at the definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(e_rev_e=f)
        bad_spec.validate()
def test_unit_init_can_make_a_defaut_spec_for_you() -> None:
    """A ``Unit`` built without a spec must equal-compare to the default.

    The function name keeps its historical spelling ("defaut") so the
    collected test ID does not change.
    """
    unit = un.Unit()
    assert unit.spec == sp.UnitSpec()
def test_it_should_validate_spike_gain_for_insane_floats(f) -> None:
    """``UnitSpec.validate()`` must reject the supplied ``spike_gain`` value.

    ``f`` is injected by a decorator or fixture that is not visible in
    this block; judging by the test name it is presumably a non-finite
    float (NaN or infinity) -- confirm at the definition site.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.UnitSpec(spike_gain=f)
        bad_spec.validate()