def test_kwta_suppresses_all_but_k_units() -> None:
    """With k=2 kWTA inhibition, exactly two of the three lr2 units end up
    with activation above 0.8 after 100 cycles.

    lr1 (one unit, forced to 1) drives every lr2 unit through proj1
    (weight 0.3). proj2 (weight 0.5) adds input only where its
    post_mask [0, 1, 1] is set.
    """
    network = net.Net()
    network.new_layer(name="lr1", size=1)

    # Adaptation and spiking are switched off so only inhibition shapes acts.
    quiet_units = sp.UnitSpec(adapt_dt=0, spike_gain=0)
    kwta_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=quiet_units)
    network.new_layer(name="lr2", size=3, spec=kwta_spec)

    weak_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    network.new_projn("proj1", "lr1", "lr2", weak_spec)
    strong_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    network.new_projn("proj2", "lr1", "lr2", strong_spec)

    network.force_layer("lr1", [1])
    for _ in range(100):
        network.cycle()

    cycle_logs = network.logs("cycle", "lr2")
    final_acts = cycle_logs[cycle_logs.time == 99]["act"]
    assert (final_acts > 0.8).sum() == 2
def __init__(self,
             name: str,
             pre: layer.Layer,
             post: layer.Layer,
             spec: specs.ProjnSpec = None) -> None:
    """Create a full projection from `pre` to `post`.

    Args:
        name: Name of the projection.
        pre: Presynaptic (sending) layer.
        post: Postsynaptic (receiving) layer.
        spec: Projection parameters; a default `specs.ProjnSpec()` is used
            when omitted.
    """
    self.name = name
    self.pre = pre
    self.post = post
    self.spec = specs.ProjnSpec() if spec is None else spec

    # Weight matrix: one row per postsynaptic unit, one column per
    # presynaptic unit. Starts at zero; only masked entries are filled below.
    self.wts = torch.Tensor(self.post.size, self.pre.size).zero_()

    # Connect only the units selected by the (tiled) pre/post masks.
    # Only full connectivity is implemented in this version.
    # TODO: Refactor mask expansion and creation into new methods + test
    pre_units = tile(self.pre.size, self.spec.pre_mask)
    post_units = tile(self.post.size, self.spec.post_mask)
    conn_mask = expand_layer_mask_full(pre_units, post_units)

    # Drop connections according to the spec's sparsity.
    # TODO: Make this a separate method
    conn_mask, n_conns = sparsify(self.spec.sparsity, conn_mask)

    # Draw one value per surviving connection from the spec's distribution.
    initial_values = torch.Tensor(n_conns)
    self.spec.dist.fill(initial_values)
    self.wts[conn_mask] = initial_values
def test_projn_one_to_one_connectivity_pattern_is_correct() -> None:
    """A one-to-one projection between two 3-unit layers with all weights 1.0
    has an identity weight matrix."""
    sender = lr.Layer("lr1", size=3)
    receiver = lr.Layer("lr2", size=3)
    one_to_one = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(1.0))
    projn = pr.Projn("proj", sender, receiver, one_to_one)
    matches_identity = projn.wts == torch.eye(3)
    assert matches_identity.all()
def __init__(self,
             name: str,
             pre: layer.Layer,
             post: layer.Layer,
             spec: specs.ProjnSpec = None) -> None:
    """Create a projection from `pre` to `post`.

    Args:
        name: Name of the projection.
        pre: Presynaptic (sending) layer.
        post: Postsynaptic (receiving) layer.
        spec: Projection parameters; a default `specs.ProjnSpec()` is used
            when omitted.

    Raises:
        ValueError: If `spec.projn_type` is neither "one_to_one" nor "full".
    """
    self._name = name
    self.pre = pre
    self.post = post
    self.cos_diff = 0.0
    self.cos_diff_avg = 0.0
    self.blocked = False

    if spec is None:
        self._spec = specs.ProjnSpec()
    else:
        self._spec = spec
    # NOTE(review): the code below reads `self.spec`; presumably a property
    # over `self._spec` defined elsewhere in the class — confirm.

    self.minus_phase = self._spec.minus_phase
    self.plus_phase = self._spec.plus_phase

    # A matrix where each element is the weight of a connection.
    # Rows encode the postsynaptic units, and columns encode the
    # presynaptic units. These weights are sigmoidally contrast-enhanced,
    # and are used to send net input to other neurons.
    self.wts = torch.Tensor(self.post.size, self.pre.size).zero_()

    # Only create the projection between the units selected by the masks,
    # using the connectivity pattern named by spec.projn_type.
    tiled_pre_mask = tile(self.pre.size, self.spec.pre_mask)
    tiled_post_mask = tile(self.post.size, self.spec.post_mask)
    if self.spec.projn_type == "one_to_one":
        mask = expand_layer_mask_one_to_one(tiled_pre_mask, tiled_post_mask)
    elif self.spec.projn_type == "full":
        mask = expand_layer_mask_full(tiled_pre_mask, tiled_post_mask)
    else:
        # Without this branch `mask` would be unbound and the failure would
        # surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported projn_type {self.spec.projn_type!r}; "
            "expected 'one_to_one' or 'full'.")

    # Enforce sparsity
    self.mask, num_nonzero = sparsify(self.spec.sparsity, mask)

    # Fill the weight matrix with values
    rand_nums = torch.Tensor(num_nonzero)
    self.spec.dist.fill(rand_nums)
    self.wts[self.mask] = rand_nums

    # These weights ("fast weights") are linear and not contrast enhanced.
    # They start equal to wts but must be a separate tensor: aliasing would
    # make any in-place update to one silently modify the other.
    # NOTE(review): this assumes the initial contrast enhancement is the
    # identity; if fwts should instead be an inverse-sigmoid of wts, confirm.
    self.fwts = self.wts.clone()

    # Record the number of incoming connections for each unit
    # (mask rows are postsynaptic units, so sum across columns).
    self.num_recv_conns = torch.sum(self.mask, dim=1).float()

    # When adding any loggable attribute or property to these lists, update
    # specs.ProjnSpec._valid_log_on_cycle (we represent in two places to
    # avoid a circular dependency)
    whole_attrs: List[str] = ["cos_diff_avg"]
    parts_attrs: List[str] = ["conn_wt", "conn_fwt"]

    super().__init__(whole_attrs=whole_attrs, parts_attrs=parts_attrs)
def test_projns_can_be_sparse() -> None:
    """With sparsity=0.5 and unit weights, exactly two of the four
    connections between two 2-unit layers survive."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    sparse_spec = sp.ProjnSpec(dist=rn.Scalar(1.0), sparsity=0.5)
    projn = pr.Projn("proj", sender, receiver, sparse_spec)
    assert projn.wts.sum() == 2.0
def test_projn_pre_mask_truncates_if_it_is_too_long() -> None:
    """A pre_mask longer than the presynaptic layer is cut to the layer size."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    long_mask_spec = sp.ProjnSpec(pre_mask=(True, False), dist=rn.Scalar(1))
    projn = pr.Projn("proj", sender, receiver, long_mask_spec)
    weights = projn.wts
    assert weights[0, 0] == 1
    assert weights.shape == (1, 1)
def test_a_new_projn_validates_its_spec() -> None:
    """Net.new_projn raises ValidationError for a spec with integ=-1."""
    network = net.Net()
    for layer_name in ("layer1", "layer2"):
        network.new_layer(layer_name, 3)
    with pytest.raises(specs.ValidationError):
        network.new_projn(
            "projn1",
            "layer1",
            "layer2",
            spec=specs.ProjnSpec(integ=-1))
def test_you_can_log_projection_weights() -> None:
    """observe_parts_attr("conn_wt") reports one entry per connection of a
    one-to-one projection, with pre/post unit indices and weights."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    spec = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(0.5))
    projn = pr.Projn("proj", sender, receiver, spec=spec)
    expected = dict(pre_unit=[0, 1], post_unit=[0, 1], conn_wt=[0.5, 0.5])
    observed = projn.observe_parts_attr("conn_wt")
    assert observed == expected
def test_projn_can_mask_post_layer_units() -> None:
    """Rows of postsynaptic units excluded by post_mask get zero weights;
    included rows get the distribution's value (1)."""
    pre = lr.Layer("lr1", size=2)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(post_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for row in range(post.size):
        expected = 1 if mask[row] else 0
        for col in range(pre.size):
            assert projn.wts[row, col] == expected
def test_projn_pre_mask_tiles_if_it_is_too_short() -> None:
    """A pre_mask shorter than the presynaptic layer is repeated to fill it,
    so column j uses mask entry j % len(mask)."""
    pre = lr.Layer("lr1", size=4)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(pre_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for row in range(post.size):
        for col in range(pre.size):
            expected = 1 if mask[col % len(mask)] else 0
            assert projn.wts[row, col] == expected
def __init__(self,
             name: str,
             pre: layer.Layer,
             post: layer.Layer,
             spec: specs.ProjnSpec = None) -> None:
    """Create a projection as a list of connections from `pre` to `post`.

    Args:
        name: Name of the projection (also passed to the connection builder).
        pre: Presynaptic (sending) layer.
        post: Postsynaptic (receiving) layer.
        spec: Projection parameters; a default `specs.ProjnSpec()` is used
            when omitted.
    """
    self.name = name
    self.pre = pre
    self.post = post
    self.spec = spec if spec is not None else specs.ProjnSpec()
    # Every connection is built with a default ConnSpec.
    default_conn_spec = specs.ConnSpec()
    self.conns = make_full_conn_list(name, pre, post, default_conn_spec)
def test_projn_can_calculate_netin_scale_with_partial_connectivity(
        x, z, m, n, f) -> None:
    """The netin scale of a fully connected projection never exceeds that of
    an otherwise identical projection whose post_mask covers only some units.

    Both senders are hard-clamped to activation `f` across their `x` units.
    """
    pre_a = lr.Layer("lr1", size=x)
    pre_b = lr.Layer("lr2", size=x)
    post = lr.Layer("lr3", size=z)
    partial_mask = (True, ) * m + (False, ) * n
    partial_spec = sp.ProjnSpec(post_mask=partial_mask)

    for sender in (pre_a, pre_b):
        sender.hard_clamp(torch.ones(x) * f)

    full_projn = pr.Projn("proj1", pre_a, post)
    partial_projn = pr.Projn("proj2", pre_b, post, partial_spec)

    full_scale = full_projn.netin_scale()
    partial_scale = partial_projn.netin_scale()
    assert (full_scale > partial_scale).sum() == 0
def test_projn_spec_validates_different_plus_and_minus_phases() -> None:
    """Validation fails when minus and plus phases are both ev.NonePhase."""
    with pytest.raises(sp.ValidationError):
        same_phases = sp.ProjnSpec(
            minus_phase=ev.NonePhase, plus_phase=ev.NonePhase)
        same_phases.validate()
def test_projn_spec_validates_integ(f) -> None:
    """Validating a ProjnSpec built with integ=f raises ValidationError.

    `f` is supplied by the caller (presumably a parametrizing decorator not
    shown here).
    """
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(integ=f).validate()
def test_projn_can_specify_its_weight_distribution() -> None:
    """A Scalar(7) distribution fills every weight of a full projection with 7."""
    sender = lr.Layer("lr1", size=3)
    receiver = lr.Layer("lr2", size=3)
    constant_spec = sp.ProjnSpec(dist=rn.Scalar(7))
    projn = pr.Projn("proj", sender, receiver, constant_spec)
    all_sevens = projn.wts == 7
    assert all_sevens.all()
def test_projn_spec_validates_the_distribution() -> None:
    """A plain integer is rejected as a weight distribution."""
    with pytest.raises(sp.ValidationError):
        bad_dist_spec = sp.ProjnSpec(dist=3)
        bad_dist_spec.validate()
def test_projn_spec_validates_projn_type(f) -> None:
    """Only "one_to_one" and "full" pass validation as projn_type."""
    is_supported = f in ("one_to_one", "full")
    if is_supported:
        sp.ProjnSpec(projn_type=f).validate()
        return
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(projn_type=f).validate()
def test_projn_spec_validates_sparsity(f) -> None:
    """Validating a ProjnSpec built with sparsity=f raises ValidationError."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(sparsity=f)
        spec.validate()
def test_projn_spec_validates_attrs_to_log() -> None:
    """Asking to log an unknown attribute ("whales") fails validation."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(log_on_cycle=("whales", ))
        spec.validate()
def test_projn_spec_validates_cos_diff_thr_l_mix(f) -> None:
    """Validating a ProjnSpec built with cos_diff_thr_l_mix=f raises
    ValidationError."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(cos_diff_thr_l_mix=f)
        spec.validate()
def test_projn_spec_validates_wt_scale_rel(f) -> None:
    """Validating a ProjnSpec built with wt_scale_rel=f raises ValidationError."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(wt_scale_rel=f)
        spec.validate()
def test_projn_spec_validates_sig_offset(f) -> None:
    """Validating a ProjnSpec built with sig_offset=f raises ValidationError."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(sig_offset=f)
        spec.validate()
def test_projn_spec_validates_lrate(f) -> None:
    """Validating a ProjnSpec built with lrate=f raises ValidationError."""
    with pytest.raises(sp.ValidationError):
        spec = sp.ProjnSpec(lrate=f)
        spec.validate()