def test_make_full_conn_list_returns_a_full_connection_list() -> None:
    """make_full_conn_list pairs every pre unit with every post unit.

    The expected order iterates post units fastest within each pre unit.
    """
    pre = lr.Layer(name="pre", size=3)
    post = lr.Layer(name="post", size=3)
    conns = pr.make_full_conn_list("proj", pre, post, sp.ConnSpec())
    expected_pairs = [(u, v) for u in pre.units for v in post.units]
    # Comparing (pre, post) pairs checks both endpoints and their pairing at once.
    assert [(c.pre, c.post) for c in conns] == expected_pairs
def test_projn_one_to_one_connectivity_pattern_is_correct() -> None:
    """A one_to_one projection with constant weight 1.0 has an identity weight matrix."""
    sender = lr.Layer("lr1", size=3)
    receiver = lr.Layer("lr2", size=3)
    one_to_one_spec = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(1.0))
    projn = pr.Projn("proj", sender, receiver, one_to_one_spec)
    identity = torch.eye(3)
    assert (projn.wts == identity).all()
def test_projn_pre_mask_truncates_if_it_is_too_long() -> None:
    """A 2-entry pre_mask on a 1-unit sender yields a single 1x1 weight of 1."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    projn = pr.Projn(
        "proj",
        sender,
        receiver,
        sp.ProjnSpec(pre_mask=(True, False), dist=rn.Scalar(1)),
    )
    weights = projn.wts
    assert weights[0, 0] == 1
    assert weights.shape == (1, 1)
def test_projns_can_be_sparse() -> None:
    """With sparsity 0.5 and unit weights, a 2x2 projection has total weight 2."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    sparse_spec = sp.ProjnSpec(dist=rn.Scalar(1.0), sparsity=0.5)
    projn = pr.Projn("proj", sender, receiver, sparse_spec)
    assert projn.wts.sum() == 2.0
def test_projn_can_handle_learn_events(mocker) -> None:
    """Handling a Learn event results in exactly one call to Projn.learn."""
    projn = pr.Projn("proj", lr.Layer("lr1", size=2), lr.Layer("lr2", size=2))
    mocker.spy(projn, "learn")
    projn.handle(ev.Learn())
    projn.learn.assert_called_once()
def test_you_can_log_projection_weights() -> None:
    """observe_parts_attr("conn_wt") reports endpoints and weight per connection."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    one_to_one = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(0.5))
    projn = pr.Projn("proj", sender, receiver, spec=one_to_one)
    observed = projn.observe_parts_attr("conn_wt")
    assert observed == {
        "pre_unit": [0, 1],
        "post_unit": [0, 1],
        "conn_wt": [0.5, 0.5],
    }
def test_projn_can_inhibit_flush() -> None:
    """After inhibit(), flushing leaves the receiver's input buffer at zero."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    projn = pr.Projn("proj", sender, receiver)
    sender.clamp(act_ext=[1])
    projn.inhibit()
    projn.flush()
    assert receiver.input_buffer == 0.0
def test_projn_inhibit_handling_event() -> None:
    """An InhibitProjns event naming this projection blocks its flush."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    projn = pr.Projn("proj", sender, receiver)
    sender.clamp(act_ext=[1])
    inhibit_event = ev.InhibitProjns("proj")
    projn.handle(inhibit_event)
    projn.flush()
    assert receiver.input_buffer == 0.0
def test_projn_can_mask_post_layer_units() -> None:
    """Receiving units excluded by post_mask get weight 0; included ones get 1."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    mask = (True, False)
    projn = pr.Projn(
        "proj", sender, receiver, sp.ProjnSpec(post_mask=mask, dist=rn.Scalar(1))
    )
    for post_idx in range(receiver.size):
        # Every weight into a receiving unit follows that unit's mask entry.
        expected_wt = 1 if mask[post_idx] else 0
        for pre_idx in range(sender.size):
            assert projn.wts[post_idx, pre_idx] == expected_wt
def test_projn_pre_mask_tiles_if_it_is_too_short() -> None:
    """A pre_mask shorter than the sending layer is repeated to cover it."""
    pre = lr.Layer("lr1", size=4)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(pre_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for i in range(post.size):
        for j in range(pre.size):
            # Tiling means sending unit j uses mask entry j modulo the mask
            # length; derive it from the mask instead of hard-coding 2.
            expected_wt = 1 if mask[j % len(mask)] else 0
            assert projn.wts[i, j] == expected_wt
def test_projn_can_uninhibit_flush() -> None:
    """An inhibited projection delivers input again after UninhibitProjns."""
    sender = lr.Layer("lr1", size=1, spec=sp.LayerSpec(clamp_max=1))
    receiver = lr.Layer("lr2", size=1)
    projn = pr.Projn("proj", sender, receiver)
    sender.clamp(act_ext=[1])

    projn.handle(ev.InhibitProjns("proj"))
    projn.flush()

    projn.handle(ev.UninhibitProjns("proj"))
    projn.flush()

    assert receiver.input_buffer == 0.5
def test_oscill_reset_inhib_event(mocker) -> None:
    """Ending oscillating inhibition for a targeted layer resets its kWTA once.

    NOTE(review): another test in this file has the same name. If both live in
    one module, the later definition shadows this one and pytest never collects
    it, so consider renaming one of them. This test uses
    ``ev.OscillEndInhibition`` while the other uses ``ev.EndOscillInhibition``;
    confirm which event class actually exists.
    """
    layer = lr.Layer("lr1", 3)
    theta = osc.Oscill("theta1", ["lr1", "lr2"])
    theta.cycle()
    # Spy on the method being asserted. Spying on _set_kwta left _reset_kwta
    # a plain bound method with no assert_called_once attribute.
    mocker.spy(layer, "_reset_kwta")
    layer.handle(ev.OscillEndInhibition(theta.layer_names))
    layer._reset_kwta.assert_called_once()
def test_layer_change_and_reset_inhibition() -> None:
    """_set_kwta(0.9) on 10 units makes k == 9; _reset_kwta moves k off 9."""
    expected_k = 9
    lyr = lr.Layer(name="in", size=10)

    lyr._set_kwta(0.9)
    assert lyr.k == expected_k

    lyr._reset_kwta()
    assert lyr.k != expected_k
def test_layer_set_hard_clamp() -> None:
    """Hard clamping [0, 1] on 3 units gives activations [0, 0.95, 0] after a cycle."""
    lyr = lr.Layer(name="in", size=3)
    lyr.hard_clamp(act_ext=[0, 1])
    lyr.activation_cycle()
    for idx, want in enumerate([0, 0.95, 0]):
        assert math.isclose(lyr.units.act[idx], want, abs_tol=1e-6)
def test_layer_hard_clamping_respects_clamp_max() -> None:
    """Hard clamping, directly and via a HardClamp event, caps acts at clamp_max."""
    layer = lr.Layer(name="in", size=3, spec=sp.LayerSpec(clamp_max=0.5))
    layer.hard_clamp(act_ext=[0, 1])
    # Address the event to this layer. The previous "lr1" did not match the
    # layer's name ("in"), so the event path was likely never exercised.
    layer.handle(ev.HardClamp(layer_name=layer.name, acts=[0, 1]))
    expected = [0, 0.5, 0]
    for i, want in enumerate(expected):
        assert math.isclose(layer.units.act[i], want, abs_tol=1e-6)
def test_layer_can_add_input(n, d) -> None:
    """wt_scale_rel_sum accumulates the weight scale passed with each input.

    Args:
        n: Number of inputs to add.
        d: Layer size, which is also the length of each input vector.

    NOTE(review): n and d are presumably supplied by a parametrizing or
    property-based decorator outside this view; confirm.
    """
    layer = lr.Layer(name="in", size=d)
    wt_scales = np.random.uniform(low=0.0, size=(n, ))
    running_total = 0.0
    for wt_scale in wt_scales:
        # torch.Tensor(d) allocates an *uninitialized* tensor of length d;
        # explicit zeros keep the input deterministic.
        layer.add_input(torch.zeros(d), wt_scale)
        # Running total replaces re-summing the whole prefix on every step.
        running_total += wt_scale
        assert math.isclose(layer.wt_scale_rel_sum, running_total, abs_tol=1e-6)
def test_oscill_reset_inhib_event(mocker) -> None:
    """Ending oscillating inhibition for other layers leaves this layer's kWTA alone.

    NOTE(review): another test in this file has the same name. If both live in
    one module, this later definition shadows the earlier one; consider
    renaming one of them.
    """
    layer = lr.Layer("lr1", 3)
    theta = osc.Oscill("theta1", ["lr2"])
    theta.cycle()
    mocker.spy(layer, "_reset_kwta")
    layer.handle(ev.EndOscillInhibition(theta.layer_names))
    # The oscillator only names "lr2", so "lr1" must not reset. The previous
    # pytest.raises(AssertionError) around assert_called_once also passed
    # when the method was called two or more times.
    layer._reset_kwta.assert_not_called()
def test_projn_can_calculate_netin_scale_with_partial_connectivity(
        x, z, m, n, f) -> None:
    """Masking part of the receiver never gives a smaller netin scale than full connectivity."""
    full_src = lr.Layer("lr1", size=x)
    masked_src = lr.Layer("lr2", size=x)
    target = lr.Layer("lr3", size=z)
    # First m receiving units are connected, the next n are masked out.
    mask_spec = sp.ProjnSpec(post_mask=(True, ) * m + (False, ) * n)

    full_src.hard_clamp(torch.ones(x) * f)
    masked_src.hard_clamp(torch.ones(x) * f)

    full_projn = pr.Projn("proj1", full_src, target)
    masked_projn = pr.Projn("proj2", masked_src, target, mask_spec)

    full_scale = full_projn.netin_scale()
    masked_scale = masked_projn.netin_scale()
    assert not (full_scale > masked_scale).any()
def test_layer_soft_clamping_equivalent_to_input() -> None:
    """Soft clamping a pattern matches adding the same pattern as input each cycle.

    Both layers run 50 activation cycles before their activations are compared.
    """
    layer1 = lr.Layer(name="test", size=3)
    layer2 = lr.Layer(name="alt", size=3, spec=sp.LayerSpec(clamp_max=1))
    layer2.clamp([0, 1, 0], hard=False)
    for _ in range(50):
        layer1.add_input(torch.Tensor([0, 1, 0]))
        layer1.activation_cycle()
        layer2.activation_cycle()
    # Leftover debug print() calls were removed; pytest reports values on failure.
    for i in range(3):
        assert math.isclose(layer1.units.act[i], layer2.units.act[i], abs_tol=1e-6)
def test_layer_can_update_learning_averages_when_hard_clamped(mocker) -> None:
    """With a hard clamp, the cycle and trial learning-average updates each run once.

    One activation cycle and one EndPlusPhase event are driven through the layer.
    """
    lyr = lr.Layer(name="layer1", size=3)
    mocker.spy(lyr, "update_trial_learning_averages")
    units = lyr.units
    mocker.spy(units, "update_cycle_learning_averages")

    lyr.hard_clamp([1.0])
    lyr.activation_cycle()
    lyr.handle(ev.EndPlusPhase())

    units.update_cycle_learning_averages.assert_called_once()
    lyr.update_trial_learning_averages.assert_called_once()
def test_projn_can_calculate_netin_scale_with_full_connectivity(x, y, z, f) -> None:
    """The larger sending layer never gets the larger netin scale.

    With full connectivity, equal sending sizes must give equal scales.
    """
    src_a = lr.Layer("lr1", size=x)
    src_b = lr.Layer("lr2", size=y)
    target = lr.Layer("lr3", size=z)
    src_a.hard_clamp(torch.ones(x) * f)
    src_b.hard_clamp(torch.ones(y) * f)

    scale_a = pr.Projn("proj1", src_a, target).netin_scale()
    scale_b = pr.Projn("proj2", src_b, target).netin_scale()

    # Build the element-wise mask of relationships that must never hold.
    if x > y:
        violations = scale_a > scale_b
    elif x < y:
        violations = scale_a < scale_b
    else:
        violations = scale_a != scale_b
    assert not violations.any()
def new_layer(self, name: str, size: int, spec: specs.LayerSpec = None) -> None:
    """Creates a layer and registers it with the network.

    The new layer is stored under ``name`` in both ``self.layers`` and
    ``self.objs``, and loggers are attached via ``self._add_loggers``.

    Args:
        name: The name of the layer.
        size: How many units the layer should have.
        spec: The layer specification. It is validated only when one is given.

    Raises:
        ValidationError: If ``spec.validate()`` rejects a parameter value.
    """
    if spec is not None:
        spec.validate()
    created = layer.Layer(name, size, spec=spec)
    # Chained assignment stores into self.layers first, then self.objs.
    self.layers[name] = self.objs[name] = created
    self._add_loggers(created)
def new_layer(self, name: str, size: int, spec: specs.LayerSpec = None) -> None:
    """Creates a layer and registers it with the network.

    The layer is appended to ``self.layers`` and indexed by its name in
    ``self.objs``. A cycle logger is added when the layer's spec lists any
    attributes in ``log_on_cycle``.

    Args:
        name: The name of the layer.
        size: How many units the layer should have.
        spec: The layer specification. It is validated only when one is given.

    Raises:
        ValidationError: If ``spec.validate()`` rejects a parameter value.
    """
    if spec is not None:
        spec.validate()
    created = layer.Layer(name, size, spec)
    self.layers.append(created)
    self.objs[created.name] = created
    cycle_attrs = created.spec.log_on_cycle
    if cycle_attrs != ():
        self.cycle_loggers.append(log.Logger(created, cycle_attrs))
def test_projn_has_a_name() -> None:
    """A projection exposes the name it was constructed with."""
    projn = pr.Projn("proj", lr.Layer("lr1", size=1), lr.Layer("lr2", size=1))
    assert projn.name == "proj"
def test_projn_has_a_sending_layer() -> None:
    """A projection exposes its sending layer as ``pre``."""
    pre = lr.Layer("lr1", size=1)
    post = lr.Layer("lr2", size=1)
    projn = pr.Projn("proj", pre, post)
    assert projn.pre == pre
def test_projn_can_flush() -> None:
    """Flushing a freshly built projection completes without raising."""
    sender, receiver = lr.Layer("lr1", size=1), lr.Layer("lr2", size=1)
    pr.Projn("proj", sender, receiver).flush()
def test_projn_can_specify_its_weight_distribution() -> None:
    """A constant Scalar(7) weight distribution makes every weight equal 7."""
    sender = lr.Layer("lr1", size=3)
    receiver = lr.Layer("lr2", size=3)
    constant_dist = rn.Scalar(7)
    projn = pr.Projn("proj", sender, receiver, sp.ProjnSpec(dist=constant_dist))
    assert (projn.wts == 7).all()
def test_projn_has_a_receiving_layer() -> None:
    """A projection exposes its receiving layer as ``post``."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    assert pr.Projn("proj", sender, receiver).post == receiver
def test_projn_can_learn() -> None:
    """Calling learn() on a freshly built projection completes without raising."""
    sender, receiver = lr.Layer("lr1", size=2), lr.Layer("lr2", size=2)
    pr.Projn("proj", sender, receiver).learn()
def test_observing_invalid_parts_attr_raises_value_error() -> None:
    """Asking for an unknown connection attribute raises ValueError."""
    projn = pr.Projn("proj", lr.Layer("lr1", size=2), lr.Layer("lr2", size=2))
    with pytest.raises(ValueError):
        projn.observe_parts_attr("whales")