Example #1
# Module aliases assumed throughout these examples (a leabra7-style layout;
# an assumption, not copied from the source files):
#   import pytest, torch
#   from leabra7 import events as ev, layer, layer as lr, net
#   from leabra7 import projn as pr, rand as rn, specs, specs as sp
def test_kwta_suppresses_all_but_k_units() -> None:
    n = net.Net()

    n.new_layer(name="lr1", size=1)

    lr2_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    n.new_layer(name="lr2", size=3, spec=lr2_spec)

    pr1_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    n.new_projn("proj1", "lr1", "lr2", pr1_spec)

    pr2_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    n.new_projn("proj2", "lr1", "lr2", pr2_spec)

    n.force_layer("lr1", [1])
    for i in range(100):
        n.cycle()

    logs = n.logs("cycle", "lr2")
    acts = logs[logs.time == 99]["act"]
    assert (acts > 0.8).sum() == 2
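For orientation, k-winners-take-all (kWTA) inhibition raises the layer's inhibition until only the k most strongly driven units stay above threshold, which is why exactly two of the three lr2 units end up with act > 0.8. The selection step can be pictured with the hypothetical helper below; the real algorithm sets an inhibition level between the k-th and (k+1)-th strongest drives rather than masking activations directly.

import torch

def kwta_winners(netins: torch.Tensor, k: int) -> torch.Tensor:
    """Boolean mask of the k units with the largest net input (sketch only)."""
    winners = torch.zeros_like(netins, dtype=torch.bool)
    winners[torch.topk(netins, k).indices] = True
    return winners

# kwta_winners(torch.tensor([0.9, 0.1, 0.8]), k=2)
# -> tensor([True, False, True])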
Example #2
    def __init__(self,
                 name: str,
                 pre: layer.Layer,
                 post: layer.Layer,
                 spec: specs.ProjnSpec = None) -> None:
        self.name = name
        self.pre = pre
        self.post = post

        if spec is None:
            self.spec = specs.ProjnSpec()
        else:
            self.spec = spec

        # A matrix where each element is the weight of a connection.
        # Rows encode the postsynaptic units, and columns encode the
        # presynaptic units.
        self.wts = torch.Tensor(self.post.size, self.pre.size).zero_()

        # Only create the projection between the units selected by the masks
        # Currently, only full connections are supported
        # TODO: Refactor mask expansion and creation into new methods + test
        tiled_pre_mask = tile(self.pre.size, self.spec.pre_mask)
        tiled_post_mask = tile(self.post.size, self.spec.post_mask)
        mask = expand_layer_mask_full(tiled_pre_mask, tiled_post_mask)

        # Enforce sparsity
        # TODO: Make this a separate method
        mask, num_nonzero = sparsify(self.spec.sparsity, mask)

        # Fill the weight matrix with values
        rand_nums = torch.Tensor(num_nonzero)
        self.spec.dist.fill(rand_nums)
        self.wts[mask] = rand_nums
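The mask expansion above is just an outer product of the two boolean layer masks: connection (i, j) survives only if post unit i and pre unit j are both unmasked. A minimal sketch of that step, assuming the masks are already boolean tensors of the layers' sizes (illustrative, not the library source):

import torch

def expand_layer_mask_full(pre_mask: torch.Tensor,
                           post_mask: torch.Tensor) -> torch.Tensor:
    """(post x pre) mask where mask[i, j] = post_mask[i] and pre_mask[j]."""
    return post_mask.unsqueeze(1) & pre_mask.unsqueeze(0)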
Example #3
def test_projn_one_to_one_connectivity_pattern_is_correct() -> None:
    pre = lr.Layer("lr1", size=3)
    post = lr.Layer("lr2", size=3)
    projn = pr.Projn(
        "proj", pre, post,
        sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(1.0)))
    assert (projn.wts == torch.eye(3)).all()
Example #4
    def __init__(self,
                 name: str,
                 pre: layer.Layer,
                 post: layer.Layer,
                 spec: specs.ProjnSpec = None) -> None:
        self._name = name
        self.pre = pre
        self.post = post

        self.cos_diff = 0.0
        self.cos_diff_avg = 0.0
        self.blocked = False

        if spec is None:
            self._spec = specs.ProjnSpec()
        else:
            self._spec = spec

        self.minus_phase = self._spec.minus_phase
        self.plus_phase = self._spec.plus_phase

        # A matrix where each element is the weight of a connection.
        # Rows encode the postsynaptic units, and columns encode the
        # presynaptic units. These weights are sigmoidally contrast-enhanced,
        # and are used to send net input to other neurons.
        self.wts = torch.Tensor(self.post.size, self.pre.size).zero_()
        # These weights ("fast weights") are linear and not contrast enhanced
        self.fwts = torch.Tensor(self.post.size, self.pre.size).zero_()

        # Only create connections between the units selected by the masks.
        # Supported connectivity patterns: full and one-to-one.
        tiled_pre_mask = tile(self.pre.size, self.spec.pre_mask)
        tiled_post_mask = tile(self.post.size, self.spec.post_mask)

        if self.spec.projn_type == "one_to_one":
            mask = expand_layer_mask_one_to_one(tiled_pre_mask,
                                                tiled_post_mask)
        elif self.spec.projn_type == "full":
            mask = expand_layer_mask_full(tiled_pre_mask, tiled_post_mask)

        # Enforce sparsity
        self.mask, num_nonzero = sparsify(self.spec.sparsity, mask)

        # Fill the weight matrix with values
        rand_nums = torch.Tensor(num_nonzero)
        self.spec.dist.fill(rand_nums)
        self.wts[self.mask] = rand_nums

        # Initialize the fast weights to match the slow weights (cloned so
        # later in-place updates do not alias the two tensors)
        self.fwts = self.wts.clone()

        # Record the number of incoming connections for each unit
        self.num_recv_conns = torch.sum(self.mask, dim=1).float()

        # When adding any loggable attribute or property to these lists, update
        # specs.ProjnSpec._valid_log_on_cycle (we represent in two places to
        # avoid a circular dependency)
        whole_attrs: List[str] = ["cos_diff_avg"]
        parts_attrs: List[str] = ["conn_wt", "conn_fwt"]

        super().__init__(whole_attrs=whole_attrs, parts_attrs=parts_attrs)
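The one_to_one branch explains the identity matrix expected in Example #3: unit i of the pre layer connects only to unit i of the post layer. A sketch under the same assumptions as the full-connectivity sketch above (illustrative, not the library source):

import torch

def expand_layer_mask_one_to_one(pre_mask: torch.Tensor,
                                 post_mask: torch.Tensor) -> torch.Tensor:
    """Diagonal (post x pre) mask, further gated by both unit masks."""
    eye = torch.eye(post_mask.numel(), pre_mask.numel(), dtype=torch.bool)
    return eye & post_mask.unsqueeze(1) & pre_mask.unsqueeze(0)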
Example #5
def test_projns_can_be_sparse() -> None:
    pre = lr.Layer("lr1", size=2)
    post = lr.Layer("lr2", size=2)
    spec = sp.ProjnSpec(dist=rn.Scalar(1.0), sparsity=0.5)
    projn = pr.Projn("proj", pre, post, spec)
    num_on = projn.wts.sum()
    assert num_on == 2.0
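This test pins down sparsify's contract: with sparsity=0.5 over a 2x2 full mask, exactly two connections survive and receive weight 1.0. One plausible implementation, returning the same (mask, num_nonzero) pair the constructors above unpack (the body is a guess; the library's rounding and sampling may differ):

import torch

def sparsify(sparsity, mask):
    """Randomly keep a `sparsity` fraction of the currently-on connections."""
    flat = mask.flatten()
    on = flat.nonzero().flatten()
    keep = on[torch.randperm(on.numel())[:int(sparsity * on.numel())]]
    thinned = torch.zeros_like(flat)
    thinned[keep] = True
    return thinned.view_as(mask), keep.numel()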
Example #6
def test_projn_pre_mask_truncates_if_it_is_too_long() -> None:
    pre = lr.Layer("lr1", size=1)
    post = lr.Layer("lr2", size=1)
    spec = sp.ProjnSpec(pre_mask=(True, False), dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    assert projn.wts[0, 0] == 1
    assert projn.wts.shape == (1, 1)
Example #7
def test_a_new_projn_validates_its_spec() -> None:
    n = net.Net()
    n.new_layer("layer1", 3)
    n.new_layer("layer2", 3)
    with pytest.raises(specs.ValidationError):
        n.new_projn(
            "projn1", "layer1", "layer2", spec=specs.ProjnSpec(integ=-1))
Example #8
def test_you_can_log_projection_weights() -> None:
    pre = lr.Layer("lr1", size=2)
    post = lr.Layer("lr2", size=2)
    projn = pr.Projn("proj",
                     pre,
                     post,
                     spec=sp.ProjnSpec(projn_type="one_to_one",
                                       dist=rn.Scalar(0.5)))
    expected = {"pre_unit": [0, 1], "post_unit": [0, 1], "conn_wt": [0.5, 0.5]}
    assert projn.observe_parts_attr("conn_wt") == expected
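The expected dict suggests how a per-connection observation can be assembled: walk the live connections in the mask and report each one's endpoints and weight. A hypothetical reconstruction consistent with the test, not the library's actual method body:

import torch

def observe_conn_wt(projn):
    """Columns of (pre unit, post unit, weight) for every live connection."""
    post_idx, pre_idx = projn.mask.nonzero(as_tuple=True)
    return {
        "pre_unit": pre_idx.tolist(),
        "post_unit": post_idx.tolist(),
        "conn_wt": projn.wts[post_idx, pre_idx].tolist(),
    }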
Example #9
def test_projn_can_mask_post_layer_units() -> None:
    pre = lr.Layer("lr1", size=2)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(post_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for i in range(post.size):
        for j in range(pre.size):
            if mask[i]:
                assert projn.wts[i, j] == 1
            else:
                assert projn.wts[i, j] == 0
Example #10
def test_projn_pre_mask_tiles_if_it_is_too_short() -> None:
    pre = lr.Layer("lr1", size=4)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(pre_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for i in range(post.size):
        for j in range(pre.size):
            if mask[j % 2]:
                assert projn.wts[i, j] == 1
            else:
                assert projn.wts[i, j] == 0
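Examples #6 and #10 together fix tile's behavior at both extremes: a mask longer than the layer is truncated, and a shorter one is repeated cyclically. A minimal sketch consistent with both tests (illustrative, not the library source):

import itertools
from typing import Iterable

import torch

def tile(size: int, mask: Iterable[bool]) -> torch.Tensor:
    """Cycle `mask` until it covers `size` units, truncating any excess."""
    tiled = itertools.islice(itertools.cycle(mask), size)
    return torch.tensor(list(tiled), dtype=torch.bool)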
Example #11
    def __init__(self,
                 name: str,
                 pre: layer.Layer,
                 post: layer.Layer,
                 spec: specs.ProjnSpec = None) -> None:
        self.name = name
        self.pre = pre
        self.post = post

        if spec is None:
            self.spec = specs.ProjnSpec()
        else:
            self.spec = spec

        self.conns = make_full_conn_list(name, pre, post, specs.ConnSpec())
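This earlier, object-per-connection variant delegates to make_full_conn_list, whose body isn't shown. A hypothetical sketch of what "full" means here, assuming a Conn class holding one pre/post unit pair (every name below is illustrative, not from the library):

def make_full_conn_list(proj_name, pre, post, spec):
    """One connection object per (pre unit, post unit) pair (sketch only)."""
    return [
        conn.Conn("{0}_{1}_{2}".format(proj_name, i, j), u, v, spec)
        for j, v in enumerate(post.units) for i, u in enumerate(pre.units)
    ]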
Example #12
def test_projn_can_calculate_netin_scale_with_partial_connectivity(
        x, z, m, n, f) -> None:
    # Property-based test: the hypothesis @given decorator that supplies
    # x, z, m, n, f (and the lone `f` in several later examples) is
    # omitted from these snippets.

    pre_a = lr.Layer("lr1", size=x)
    pre_b = lr.Layer("lr2", size=x)
    post = lr.Layer("lr3", size=z)

    spec = sp.ProjnSpec(post_mask=(True, ) * m + (False, ) * n)

    pre_a.hard_clamp(torch.ones(x) * f)
    pre_b.hard_clamp(torch.ones(x) * f)

    projn_a = pr.Projn("proj1", pre_a, post)
    projn_b = pr.Projn("proj2", pre_b, post, spec)

    projn_a_scale = projn_a.netin_scale()
    projn_b_scale = projn_b.netin_scale()

    assert torch.sum(projn_a_scale > projn_b_scale) == 0
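The assertion only constrains an ordering: masking off part of the post layer must never shrink the net-input scale. That falls out naturally if the scale is inversely related to each unit's incoming connection count (recorded as num_recv_conns in the constructor above). A hypothetical sketch of such a rule, not the library's actual formula:

import torch

def netin_scale_sketch(num_recv_conns: torch.Tensor) -> torch.Tensor:
    """Per-unit scale ~ 1 / (incoming connection count), floored at one
    connection so fully masked-off units keep a finite, maximal scale."""
    return 1.0 / torch.clamp(num_recv_conns, min=1.0)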
Example #13
def test_projn_spec_validates_different_plus_and_minus_phases() -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(minus_phase=ev.NonePhase,
                     plus_phase=ev.NonePhase).validate()
Example #14
def test_projn_spec_validates_integ(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(integ=f).validate()
Example #15
def test_projn_can_specify_its_weight_distribution() -> None:
    pre = lr.Layer("lr1", size=3)
    post = lr.Layer("lr2", size=3)
    projn = pr.Projn("proj", pre, post, sp.ProjnSpec(dist=rn.Scalar(7)))
    assert (projn.wts == 7).all()
Example #16
def test_projn_spec_validates_the_distribution() -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(dist=3).validate()
Example #17
def test_projn_spec_validates_projn_type(f) -> None:
    if f not in ["one_to_one", "full"]:
        with pytest.raises(sp.ValidationError):
            sp.ProjnSpec(projn_type=f).validate()
    else:
        sp.ProjnSpec(projn_type=f).validate()
Example #18
def test_projn_spec_validates_sparsity(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(sparsity=f).validate()
Example #19
def test_projn_spec_validates_attrs_to_log() -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(log_on_cycle=("whales", )).validate()
Example #20
def test_projn_spec_validates_cos_diff_thr_l_mix(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(cos_diff_thr_l_mix=f).validate()
Example #21
def test_projn_spec_validates_wt_scale_rel(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(wt_scale_rel=f).validate()
Example #22
def test_projn_spec_validates_sig_offset(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(sig_offset=f).validate()
Example #23
def test_projn_spec_validates_lrate(f) -> None:
    with pytest.raises(sp.ValidationError):
        sp.ProjnSpec(lrate=f).validate()