def test_unitgroup_has_the_same_behavior_as_unit() -> None:
    """A two-element UnitGroup must track two independent Units.

    Two scalar ``un.Unit`` objects and one ``un.UnitGroup(size=2)`` receive
    the same inputs and the same update sequence for 500 steps. Afterwards,
    element ``k`` of every listed group attribute must match the same
    attribute on scalar unit ``k``.
    """
    # State attributes compared between the scalar units and the group.
    attrs = (
        "net_raw",
        "net",
        "gc_i",
        "act",
        "i_net",
        "i_net_r",
        "v_m",
        "v_m_eq",
        "adapt",
        "spike",
    )

    unit0 = un.Unit()
    unit1 = un.Unit()
    group = un.UnitGroup(size=2)

    for _ in range(500):
        unit0.add_input(0.3)
        unit1.add_input(0.5)
        group.add_input(torch.Tensor([0.3, 0.5]))

        unit0.update_net()
        unit1.update_net()
        group.update_net()

        unit0.update_inhibition(0.1)
        unit1.update_inhibition(0.1)
        group.update_inhibition(torch.Tensor([0.1, 0.1]))

        unit0.update_membrane_potential()
        unit1.update_membrane_potential()
        group.update_membrane_potential()

        unit0.update_activation()
        unit1.update_activation()
        group.update_activation()

    for attr in attrs:
        group_attr = getattr(group, attr)
        # The absolute tolerance absorbs float32 (torch) vs Python-float drift.
        assert math.isclose(getattr(unit0, attr), group_attr[0], abs_tol=1e-6), attr
        assert math.isclose(getattr(unit1, attr), group_attr[1], abs_tol=1e-6), attr
def test_unit_can_calculate_the_inhibition_to_put_it_at_threshold() -> None:
    """Clamping inhibition at ``g_i_thr()`` settles v_m on the spike threshold."""
    neuron = un.Unit()
    neuron.add_input(3)
    neuron.update_net()
    threshold_inhibition = neuron.g_i_thr()
    neuron.update_inhibition(threshold_inhibition)
    # Let the membrane potential relax toward its steady state.
    for _ in range(200):
        neuron.update_membrane_potential()
    assert math.isclose(neuron.v_m, neuron.spec.spk_thr, rel_tol=1e-4)
def __init__(self, name: str, size: int, spec: specs.LayerSpec = None) -> None:
    """Create a layer holding ``size`` units.

    Args:
        name: Forwarded to the parent class initializer.
        size: Number of units to construct.
        spec: Layer spec to use; when None, a default ``specs.LayerSpec()``
            is constructed.
    """
    self.size = size
    self.spec = specs.LayerSpec() if spec is None else spec
    # Every unit is built from the same ``unit_spec`` object of the layer spec.
    self.units = [unit.Unit(self.spec.unit_spec) for _ in range(size)]
    # NOTE(review): fbi presumably means feedback inhibition — confirm.
    self.fbi = 0.0
    super().__init__(name)
def test_nxx1_equals_the_max_table_value_outside_the_max_boundary(nxx1_table):
    """An input above the table's largest x returns the last table value."""
    neuron = un.Unit()
    xs, conv = nxx1_table
    above_max = xs[-1] + 1
    assert neuron.nxx1(above_max) == conv[-1]
def test_nxx1_equals_the_lookup_table(nxx1_table):
    """nxx1 reproduces every 50th entry of the precomputed lookup table."""
    neuron = un.Unit()
    xs, conv = nxx1_table
    sample_indices = range(0, xs.size, 50)
    for idx in sample_indices:
        assert math.isclose(neuron.nxx1(xs[idx]), conv[idx])
def test_unit_raises_valueerror_if_attr_is_unobservable():
    """Observing a name the unit does not expose raises ValueError."""
    neuron = un.Unit()
    bogus_attr = "banannas"
    with pytest.raises(ValueError):
        neuron.observe(bogus_attr)
def test_unit_observation_matches_the_current_attribute_value() -> None:
    """``observe("act")`` reports the unit's current ``act`` value.

    The unit is first stepped through one full update cycle so that the
    check covers a unit whose state has been updated, not only a freshly
    constructed one. The result is expected in the ``{name: value}`` dict
    form.
    """
    neuron = un.Unit()
    neuron.add_input(0.3)
    neuron.update_net()
    neuron.update_inhibition(0.1)
    neuron.update_membrane_potential()
    neuron.update_activation()
    assert neuron.observe("act") == {"act": neuron.act}
def test_unit_can_update_its_membrane_potential():
    """Smoke test: a fresh unit can step its membrane potential without raising."""
    fresh = un.Unit()
    fresh.update_membrane_potential()
def test_unit_has_0_raw_un_input_at_first():
    """A freshly constructed unit starts with zero raw net input."""
    fresh = un.Unit()
    initial_raw_input = fresh.net_raw
    assert initial_raw_input == 0
def test_projn_has_a_sending_unit():
    """A connection exposes its sending unit as ``pre``."""
    sender, receiver = un.Unit(), un.Unit()
    connection = pr.Conn("con1", sender, receiver)
    assert connection.pre == sender
def test_conn_has_a_name():
    """A connection exposes the name it was constructed with."""
    sender, receiver = un.Unit(), un.Unit()
    connection = pr.Conn("con1", sender, receiver)
    assert connection.name == "con1"
def test_unit_init_keeps_the_exact_spec_object_you_pass_it() -> None:
    """The unit stores the very UnitSpec instance given to its constructor.

    The check uses identity (``is``), not equality, so it fails if the
    constructor copies or rebuilds the spec.
    """
    custom_spec = sp.UnitSpec()
    neuron = un.Unit(spec=custom_spec)
    assert neuron.spec is custom_spec
def test_nxx1_equals_the_min_value_outside_the_min_bound(nxx1_table) -> None:
    """An input below the table's smallest x returns the first table value."""
    neuron = un.Unit()
    xs, conv = nxx1_table
    below_min = xs[0] - 1
    assert neuron.nxx1(below_min) == conv[0]
def test_unit_init_uses_the_spec_you_pass_it():
    """The unit stores the spec argument as given, even a plain placeholder value."""
    placeholder = 3
    neuron = un.Unit(spec=placeholder)
    assert neuron.spec == placeholder
def test_conn_has_a_receiving_unit():
    """A connection exposes its receiving unit as ``post``."""
    sender, receiver = un.Unit(), un.Unit()
    connection = pr.Conn("con1", sender, receiver)
    assert connection.post == receiver
def test_unit_init_can_make_a_defaut_spec_for_you():
    """Without an explicit spec, the unit's spec equals a default UnitSpec."""
    neuron = un.Unit()
    default_spec = sp.UnitSpec()
    assert neuron.spec == default_spec
def test_conn_has_a_weight():
    """A newly created connection has a strictly positive weight."""
    sender, receiver = un.Unit(), un.Unit()
    weight = pr.Conn("con1", sender, receiver).wt
    assert weight > 0
def test_unit_can_add_inputs_to_the_raw_un_input():
    """Adding an input to a fresh unit is reflected in ``net_raw``."""
    neuron = un.Unit()
    stimulus = 3
    neuron.add_input(stimulus)
    assert neuron.net_raw == stimulus
def test_conn_init_uses_the_spec_you_pass_it():
    """A connection keeps the exact ConnSpec instance it was given."""
    conn_spec = sp.ConnSpec()
    sender, receiver = un.Unit(), un.Unit()
    connection = pr.Conn("proj1", sender, receiver, spec=conn_spec)
    assert connection.spec is conn_spec
def test_unit_can_update_its_activation():
    """Smoke test: a fresh unit can compute its activation without raising."""
    fresh = un.Unit()
    fresh.update_activation()
def test_unit_can_observe_its_attributes() -> None:
    """Observing "act" on a fresh unit yields a one-entry dict holding 0.0."""
    neuron = un.Unit()
    observation = neuron.observe("act")
    assert observation == {"act": 0.0}