def test_you_can_observe_attrs_from_the_unit_group() -> None:
    """Check that observing each loggable attribute yields per-unit records.

    For every name in ``loggable_attrs``, the record at index ``k`` must be a
    dict holding the unit index and that unit's value of the attribute.
    """
    n = 2
    ug = un.UnitGroup(size=n)
    for attr in ug.loggable_attrs:
        logs = ug.observe(attr)
        for idx in range(n):
            expected = {"unit": idx, attr: getattr(ug, attr)[idx]}
            assert logs[idx] == expected
def test_unitgroup_has_the_same_behavior_as_unit() -> None:
    """Check that a two-unit ``UnitGroup`` tracks two standalone ``Unit``s.

    Two independent units and one group of size 2 are driven through 500
    identical cycles (input, net update, inhibition, membrane potential,
    activation). Afterwards every compared attribute of the group must match
    the corresponding standalone unit within an absolute tolerance of 1e-6.
    """
    # State attributes compared between each unit and its slot in the group.
    attrs = ("net_raw", "net", "gc_i", "act", "i_net", "i_net_r", "v_m",
             "v_m_eq", "adapt", "spike")

    unit0 = un.Unit()
    unit1 = un.Unit()
    group = un.UnitGroup(size=2)

    # The cycle index itself is never used; only the repetition matters.
    for _ in range(500):
        unit0.add_input(0.3)
        unit1.add_input(0.5)
        group.add_input(torch.Tensor([0.3, 0.5]))

        unit0.update_net()
        unit1.update_net()
        group.update_net()

        unit0.update_inhibition(0.1)
        unit1.update_inhibition(0.1)
        group.update_inhibition(torch.Tensor([0.1, 0.1]))

        unit0.update_membrane_potential()
        unit1.update_membrane_potential()
        group.update_membrane_potential()

        unit0.update_activation()
        unit1.update_activation()
        group.update_activation()

    for name in attrs:
        group_attr = getattr(group, name)
        assert math.isclose(getattr(unit0, name), group_attr[0], abs_tol=1e-6)
        assert math.isclose(getattr(unit1, name), group_attr[1], abs_tol=1e-6)
def __init__(self, name: str, size: int, spec: specs.LayerSpec = None) -> None:
    """Set up the layer's units, inhibition state, and input buffers.

    Args:
        name: Stored as the layer's ``_name``.
        size: Number of units in the layer's unit group.
        spec: Layer specification. When ``None``, a default
            ``specs.LayerSpec()`` is created.
    """
    self._name = name
    self.size = size
    self._spec = specs.LayerSpec() if spec is None else spec

    self.units = unit.UnitGroup(size=size, spec=self.spec.unit_spec)

    # Inhibition state: feedback inhibition and global inhibition.
    self.fbi = 0.0
    self.gc_i = 0.0

    # Whether the layer's activation is currently clamped.
    self.clamped = False
    # Whether this is a hidden layer (i.e. it has never been clamped).
    self.hidden = True

    # Number of units (k) used for inhibition; always at least one.
    kwta_count = int(round(self.size * self.spec.kwta_pct))
    self.k = max(1, kwta_count)

    # Desired clamping values.
    self.act_ext = torch.Tensor(self.size).zero_()
    # Activations from the last plus phase.
    self.acts_p = torch.Tensor(self.size).zero_()
    # Activations from the last minus phase.
    self.acts_m = torch.Tensor(self.size).zero_()

    # Cosine similarity between acts_p and acts_m.
    self.cos_diff = 0.0
    # Cosine similarity between acts_p and acts_m, integrated over trials.
    self.cos_diff_avg = 0.0

    # The two buffers below are filled each time self.add_input() is called
    # and are reset at the end of self.activation_cycle().

    # Net input (excitation) buffer: the layer inputs for the current cycle
    # accumulate here. Once all inputs have arrived they are normalized by
    # wt_scale_rel_sum and sent to the unit group.
    self.input_buffer = torch.Tensor(self.size).zero_()
    # Sum of the wt_scale_rel parameters of every projection terminating in
    # this layer; used to normalize the inputs before they are propagated
    # to the unit group.
    self.wt_scale_rel_sum = 0.0

    # When adding a loggable attribute or property to either list, also
    # update specs.LayerSpec._valid_log_on_cycle (the names are kept in two
    # places to avoid a circular dependency).
    super().__init__(
        whole_attrs=["avg_act", "avg_net", "cos_diff_avg", "fbi"],
        parts_attrs=[
            "unit_net",
            "unit_net_raw",
            "unit_gc_i",
            "unit_act",
            "unit_i_net",
            "unit_i_net_r",
            "unit_v_m",
            "unit_v_m_eq",
            "unit_adapt",
            "unit_spike",
        ],
    )
def test_unitgroup_can_calculate_the_threshold_inhibition() -> None:
    """Check that ``g_i_thr`` holds the chosen unit at its spike threshold.

    The inhibition returned by ``g_i_thr`` for one unit is applied to the
    whole group; after 200 membrane-potential updates that unit's ``v_m``
    must be within 1e-6 of ``spec.spk_thr``.
    """
    size = 10
    probe = 2

    group = un.UnitGroup(size=size)
    group.add_input(torch.Tensor(np.linspace(0.3, 0.8, size)))
    group.update_net()

    g_i_thr = group.g_i_thr(unit_idx=probe)
    group.update_inhibition(torch.Tensor(size).fill_(g_i_thr))

    for _ in range(200):
        group.update_membrane_potential()

    within_tolerance = torch.abs(group.v_m - group.spec.spk_thr) < 1e-6
    assert within_tolerance[probe]
def test_unitgroup_can_return_the_top_k_net_input_values() -> None:
    """Check that ``top_k_net_indices(3)`` picks the three largest net inputs.

    With net inputs in descending order, the expected indices are 0, 1, 2.
    """
    group = un.UnitGroup(size=10)
    group.net = torch.Tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])

    actual = group.top_k_net_indices(3)
    expected = torch.Tensor([0, 1, 2]).long()

    assert (actual == expected).all()
def test_unitgroup_sets_the_spec_you_provide() -> None:
    """Check that a spec passed to ``UnitGroup`` is stored by identity."""
    provided = sp.UnitSpec()
    group = un.UnitGroup(size=3, spec=provided)
    assert group.spec is provided
def test_unitgroup_uses_the_default_spec_if_none_is_provided() -> None:
    """Check that a ``UnitGroup`` built without a spec equals the default."""
    group = un.UnitGroup(size=3)
    default_spec = sp.UnitSpec()
    assert group.spec == default_spec
def test_it_checks_for_unobservable_attrs() -> None:
    """Check that observing an unknown attribute raises ``ValueError``."""
    ug = un.UnitGroup(3)
    bogus_attr = "rabbit cm"
    with pytest.raises(ValueError):
        ug.observe(bogus_attr)
def test_unitgroup_update_inhibition_checks_input_dimensions() -> None:
    """Check that a wrongly sized inhibition tensor raises ``AssertionError``.

    The group has three units, but the tensor supplied has two elements.
    """
    ug = un.UnitGroup(size=3)
    mismatched = torch.Tensor([1, 2])
    with pytest.raises(AssertionError):
        ug.update_inhibition(mismatched)
def test_unitgroup_init_checks_that_size_is_positive() -> None:
    """Check that constructing a ``UnitGroup`` of size zero raises ``ValueError``."""
    invalid_size = 0
    with pytest.raises(ValueError):
        un.UnitGroup(size=invalid_size)