Ejemplo n.º 1
0
def test_kwta_suppresses_all_but_k_units() -> None:
    """With kWTA inhibition and k=2, exactly k of the three units end up active.

    Every `lr2` unit receives input through `proj1`; units 1 and 2 also
    receive input through `proj2` (its post mask excludes unit 0). On the
    last cycle, exactly `k` units should have activation above 0.8.
    """
    num_cycles = 100
    k = 2

    n = net.Net()

    n.new_layer(name="lr1", size=1)

    lr2_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=k,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    n.new_layer(name="lr2", size=3, spec=lr2_spec)

    pr1_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    n.new_projn("proj1", "lr1", "lr2", pr1_spec)

    pr2_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    n.new_projn("proj2", "lr1", "lr2", pr2_spec)

    n.force_layer("lr1", [1])
    for _ in range(num_cycles):
        n.cycle()

    logs = n.logs("cycle", "lr2")
    # Last cycle's rows; the time index is presumably 0-based (the original
    # compared against 99 after 100 cycles).
    acts = logs[logs.time == num_cycles - 1]["act"]
    assert (acts > 0.8).sum() == k
Ejemplo n.º 2
0
def test_projn_pre_mask_truncates_if_it_is_too_long() -> None:
    """A pre mask longer than the pre layer is cut down to the layer size."""
    pre_layer = lr.Layer("lr1", size=1)
    post_layer = lr.Layer("lr2", size=1)
    # Two mask entries for a single pre unit: the extra entry is ignored.
    projn_spec = sp.ProjnSpec(pre_mask=(True, False), dist=rn.Scalar(1))
    projection = pr.Projn("proj", pre_layer, post_layer, projn_spec)
    weights = projection.wts
    assert weights[0, 0] == 1
    assert weights.shape == (1, 1)
Ejemplo n.º 3
0
def test_projns_can_be_sparse() -> None:
    """At sparsity 0.5, half of a 2x2 projection's connections are kept."""
    layer_pre = lr.Layer("lr1", size=2)
    layer_post = lr.Layer("lr2", size=2)
    projection = pr.Projn(
        "proj", layer_pre, layer_post,
        sp.ProjnSpec(dist=rn.Scalar(1.0), sparsity=0.5))
    # Weights are drawn as 1.0, so the sum counts the surviving connections.
    assert projection.wts.sum() == 2.0
Ejemplo n.º 4
0
def test_projn_one_to_one_connectivity_pattern_is_correct() -> None:
    """A one-to-one projection yields an identity weight matrix."""
    one_to_one = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(1.0))
    source = lr.Layer("lr1", size=3)
    target = lr.Layer("lr2", size=3)
    projection = pr.Projn("proj", source, target, one_to_one)
    identity = torch.eye(3)
    assert (projection.wts == identity).all()
Ejemplo n.º 5
0
def test_you_can_log_projection_weights() -> None:
    """`observe_parts_attr("conn_wt")` reports per-connection weights."""
    one_to_one = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(0.5))
    projn = pr.Projn("proj",
                     lr.Layer("lr1", size=2),
                     lr.Layer("lr2", size=2),
                     spec=one_to_one)
    observed = projn.observe_parts_attr("conn_wt")
    assert observed == {
        "pre_unit": [0, 1],
        "post_unit": [0, 1],
        "conn_wt": [0.5, 0.5],
    }
Ejemplo n.º 6
0
def test_projn_can_mask_post_layer_units() -> None:
    """Post units excluded by `post_mask` get all-zero incoming weights."""
    pre_layer = lr.Layer("lr1", size=2)
    post_layer = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(post_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre_layer, post_layer, spec)
    # Weight matrix rows index post units, columns index pre units.
    for row in range(post_layer.size):
        expected = 1 if mask[row] else 0
        for col in range(pre_layer.size):
            assert projn.wts[row, col] == expected
Ejemplo n.º 7
0
def test_projn_pre_mask_tiles_if_it_is_too_short() -> None:
    """A pre mask shorter than the pre layer is repeated to cover it.

    A 2-entry mask over a 4-unit pre layer should act like
    (True, False, True, False); masked-out pre units get zero weights.
    """
    pre = lr.Layer("lr1", size=4)
    post = lr.Layer("lr2", size=2)
    mask = (True, False)
    spec = sp.ProjnSpec(pre_mask=mask, dist=rn.Scalar(1))
    projn = pr.Projn("proj", pre, post, spec)
    for i in range(post.size):
        for j in range(pre.size):
            # Tile by the mask's own length (not a hard-coded 2) so the
            # expectation stays correct if `mask` is changed.
            expected = 1 if mask[j % len(mask)] else 0
            assert projn.wts[i, j] == expected
Ejemplo n.º 8
0
class ProjnSpec(Spec):
    """Spec for `Projn` objects."""
    # The probability distribution from which the connection weights will be
    # drawn
    dist: rand.Distribution = rand.Scalar(0.5)
    # Selects which pre layer units will be included in the projection.
    # If the length is less than the number of units in the pre layer, it
    # will be tiled. If the length is more, it will be truncated.
    pre_mask: Iterable[bool] = (True, )
    # Selects which post layer units will be included in the projection.
    # If the length is less than the number of units in the post layer, it
    # will be tiled. If the length is more, it will be truncated.
    post_mask: Iterable[bool] = (True, )
    # Sparsity of the connection: the fraction (0.0 to 1.0) of connections
    # that are active.
    sparsity: float = 1.0

    def validate(self) -> None:
        """Extends `Spec.validate`.

        Checks that `dist` is a `rand.Distribution` and that `sparsity` lies
        in [0.0, 1.0], then defers to the parent validation.

        Raises:
            ValidationError: If `dist` is not a `rand.Distribution`. The
                `sparsity` range is enforced by `assert_in_range`.
        """
        if not isinstance(self.dist, rand.Distribution):
            raise ValidationError("{0} is not a valid "
                                  "distribution.".format(self.dist))
        self.assert_in_range("sparsity", low=0.0, high=1.0)
        super().validate()
Ejemplo n.º 9
0
def test_projn_can_specify_its_weight_distribution() -> None:
    """Every weight equals the constant of a `Scalar` distribution."""
    constant_spec = sp.ProjnSpec(dist=rn.Scalar(7))
    projection = pr.Projn("proj",
                          lr.Layer("lr1", size=3),
                          lr.Layer("lr2", size=3),
                          constant_spec)
    assert (projection.wts == 7).all()
Ejemplo n.º 10
0
class ProjnSpec(ObservableSpec):
    """Spec for `Projn` objects."""
    # The probability distribution from which the connection weights will be
    # drawn
    dist: rand.Distribution = rand.Scalar(0.5)
    # Selects which pre layer units will be included in the projection.
    # If the length is less than the number of units in the pre layer, it
    # will be tiled. If the length is more, it will be truncated.
    pre_mask: Iterable[bool] = (True, )
    # Selects which post layer units will be included in the projection.
    # If the length is less than the number of units in the post layer, it
    # will be tiled. If the length is more, it will be truncated.
    post_mask: Iterable[bool] = (True, )
    # Sparsity of the connection: the fraction (0.0 to 1.0) of connections
    # that are active.
    sparsity: float = 1.0
    # Set special type of projection. One of ["full", "one_to_one"]
    # (enforced in `validate`).
    projn_type = "full"
    # Absolute net input scaling weight
    wt_scale_abs: float = 1.0
    # Relative net input scaling weight (relative to other projections
    # terminating in the same layer)
    wt_scale_rel: float = 1.0
    # Learning rate
    lrate = 0.02
    # Mixing constant determining how much learning is hebbian.
    # See Emergent docs.
    # NOTE(review): validate only requires this to be >= 0; if it is a true
    # mixing proportion it should probably be capped at 1 — confirm.
    thr_l_mix = 0.1
    # Flag controlling whether thr_l_mix is modulated by cos_diff_avg
    cos_diff_thr_l_mix = False
    # Modulate the learn rate by cos_diff_avg?
    cos_diff_lrate = False
    # Gain for sigmoidal weight contrast enhancement
    sig_gain = 6
    # Offset for sigmoidal weight contrast enhancement
    sig_offset = 1
    # Minus phase (must differ from plus_phase; checked in `validate`)
    minus_phase = events.MinusPhase
    # Plus phase (must differ from minus_phase; checked in `validate`)
    plus_phase = events.PlusPhase

    @property
    def _valid_attrs_to_log(self) -> Iterable[str]:
        """Overrides `ObservableSpec._valid_attrs_to_log`."""
        # Valid attributes to log on every cycle
        # When adding any loggable attribute or property to this list,
        # update Projn._whole_attrs or Projn._parts_attrs as appropriate
        # (we represent in two places to avoid a circular dependency)
        return ("conn_wt", "conn_fwt", "cos_diff_avg")

    def validate(self) -> None:
        """Extends `ObservableSpec.validate`.

        Runs the parent validation first, then checks the weight
        distribution, the projection type, the numeric ranges of the scaling
        and learning parameters, and that the minus and plus phases differ.

        Raises:
            ValidationError: If `dist` is not a `rand.Distribution`, if
                `projn_type` is not "one_to_one" or "full", or if
                `minus_phase` equals `plus_phase`. Numeric ranges are
                enforced by `assert_in_range` / `assert_sane_float`.
        """
        super().validate()

        if not isinstance(self.dist, rand.Distribution):
            raise ValidationError("{0} is not a valid "
                                  "distribution.".format(self.dist))
        self.assert_in_range("sparsity", low=0.0, high=1.0)

        valid_projn_types = ["one_to_one", "full"]
        if self.projn_type not in valid_projn_types:
            raise ValidationError(
                "Projn type {0} not one of [\"one_to_one\", \"full\"]".format(
                    self.projn_type))

        self.assert_in_range("wt_scale_abs", 0, float("Inf"))
        self.assert_in_range("wt_scale_rel", 0, float("Inf"))
        self.assert_in_range("lrate", 0, float("Inf"))
        self.assert_in_range("thr_l_mix", 0, float("Inf"))
        # NOTE(review): cos_diff_thr_l_mix is a boolean flag, yet it gets a
        # numeric range check (True/False presumably pass as 1/0), while
        # cos_diff_lrate is not checked at all. An explicit bool check for
        # both flags may have been intended — confirm.
        self.assert_in_range("cos_diff_thr_l_mix", 0, float("Inf"))
        self.assert_in_range("sig_gain", 0, float("Inf"))
        self.assert_sane_float("sig_offset")

        if self.minus_phase == self.plus_phase:
            raise ValidationError(
                "Minus and plus phase cannot both be {0}".format(
                    self.minus_phase))
Ejemplo n.º 11
0
def test_scalar_is_always_equal_to_its_value() -> None:
    """`Scalar.fill` writes the scalar's value into every tensor element."""
    scalar = rn.Scalar(3)
    # Uninitialized storage; `fill` must overwrite every element.
    buffer = torch.empty(10)
    scalar.fill(buffer)
    assert torch.all(buffer == 3)