def test_kwta_suppresses_all_but_k_units() -> None:
    """With kWTA inhibition and k=2, exactly two of three lr2 units end up
    above 0.8 activation after 100 cycles."""
    network = net.Net()
    network.new_layer(name="lr1", size=1)

    quiet_units = sp.UnitSpec(adapt_dt=0, spike_gain=0)
    kwta_spec = sp.LayerSpec(inhibition_type="kwta",
                             k=2,
                             log_on_cycle=("unit_act", ),
                             unit_spec=quiet_units)
    network.new_layer(name="lr2", size=3, spec=kwta_spec)

    # proj1 connects lr1 to every lr2 unit with weight 0.3.
    uniform_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    network.new_projn("proj1", "lr1", "lr2", uniform_spec)
    # proj2 adds weight-0.5 connections only to lr2 units 1 and 2.
    boost_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    network.new_projn("proj2", "lr1", "lr2", boost_spec)

    network.force_layer("lr1", [1])
    for _ in range(100):
        network.cycle()

    cycle_log = network.logs("cycle", "lr2")
    final_acts = cycle_log[cycle_log.time == 99]["act"]
    assert (final_acts > 0.8).sum() == 2
def test_projn_pre_mask_truncates_if_it_is_too_long() -> None:
    """A pre_mask longer than the pre layer is cut down to the layer's size."""
    sender = lr.Layer("lr1", size=1)
    receiver = lr.Layer("lr2", size=1)
    oversized = sp.ProjnSpec(pre_mask=(True, False), dist=rn.Scalar(1))
    projn = pr.Projn("proj", sender, receiver, oversized)
    weights = projn.wts
    assert weights[0, 0] == 1
    assert weights.shape == (1, 1)
def test_projns_can_be_sparse() -> None:
    """With sparsity 0.5, two of the four unit-weight 2x2 connections remain."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    half_spec = sp.ProjnSpec(dist=rn.Scalar(1.0), sparsity=0.5)
    projn = pr.Projn("proj", sender, receiver, half_spec)
    assert projn.wts.sum() == 2.0
def test_projn_one_to_one_connectivity_pattern_is_correct() -> None:
    """A one_to_one projection between equal 3-unit layers is the identity."""
    n_units = 3
    sender = lr.Layer("lr1", size=n_units)
    receiver = lr.Layer("lr2", size=n_units)
    diagonal_spec = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(1.0))
    projn = pr.Projn("proj", sender, receiver, diagonal_spec)
    identity = torch.eye(n_units)
    matches = projn.wts == identity
    assert matches.all()
def test_you_can_log_projection_weights() -> None:
    """observe_parts_attr("conn_wt") returns parallel lists of pre unit,
    post unit and connection weight."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    one_to_one = sp.ProjnSpec(projn_type="one_to_one", dist=rn.Scalar(0.5))
    projn = pr.Projn("proj", sender, receiver, spec=one_to_one)
    observed = projn.observe_parts_attr("conn_wt")
    assert observed == dict(pre_unit=[0, 1],
                            post_unit=[0, 1],
                            conn_wt=[0.5, 0.5])
def test_projn_can_mask_post_layer_units() -> None:
    """Post units whose post_mask entry is False receive zero weights."""
    sender = lr.Layer("lr1", size=2)
    receiver = lr.Layer("lr2", size=2)
    mask = (True, False)
    projn = pr.Projn("proj", sender, receiver,
                     sp.ProjnSpec(post_mask=mask, dist=rn.Scalar(1)))
    for row in range(receiver.size):
        # Rows of wts index post units: kept rows hold 1, masked rows hold 0.
        expected = 1 if mask[row] else 0
        for col in range(sender.size):
            assert projn.wts[row, col] == expected
def test_projn_pre_mask_tiles_if_it_is_too_short() -> None:
    """A pre_mask shorter than the pre layer is repeated across its units."""
    sender = lr.Layer("lr1", size=4)
    receiver = lr.Layer("lr2", size=2)
    mask = (True, False)
    projn = pr.Projn("proj", sender, receiver,
                     sp.ProjnSpec(pre_mask=mask, dist=rn.Scalar(1)))
    # Tiling means pre unit `col` is governed by mask[col % len(mask)].
    tiled = [mask[col % len(mask)] for col in range(sender.size)]
    for row in range(receiver.size):
        for col, keep in enumerate(tiled):
            assert projn.wts[row, col] == (1 if keep else 0)
class ProjnSpec(Spec):
    """Spec for `Projn` objects."""
    # The probability distribution from which the connection weights will be
    # drawn
    dist: rand.Distribution = rand.Scalar(0.5)
    # Selects which pre layer units will be included in the projection.
    # If the length is less than the number of units in the pre_layer, it will
    # be tiled. If the length is more, it will be truncated.
    pre_mask: Iterable[bool] = (True, )
    # Selects which post layer units will be included in the projection.
    # If the length is less than the number of units in the post_layer, it
    # will be tiled. If the length is more, it will be truncated.
    post_mask: Iterable[bool] = (True, )
    # Sparsity of the connection, expressed as a fraction between 0.0 and 1.0
    # (i.e. the proportion of connections that are active).
    sparsity: float = 1.0

    def validate(self) -> None:  # pylint: disable=W0235
        """Extends `Spec.validate`.

        Checks this spec's own fields first, then defers to the base class:
        `dist` must be a `rand.Distribution` (a `ValidationError` is raised
        otherwise), and `sparsity` is range-checked against 0.0..1.0 via
        `assert_in_range`.
        """
        if not isinstance(self.dist, rand.Distribution):
            raise ValidationError("{0} is not a valid "
                                  "distribution.".format(self.dist))
        self.assert_in_range("sparsity", low=0.0, high=1.0)
        super().validate()
def test_projn_can_specify_its_weight_distribution() -> None:
    """Every weight comes from the spec's distribution (here Scalar(7))."""
    sender = lr.Layer("lr1", size=3)
    receiver = lr.Layer("lr2", size=3)
    sevens = sp.ProjnSpec(dist=rn.Scalar(7))
    projn = pr.Projn("proj", sender, receiver, sevens)
    all_seven = projn.wts == 7
    assert all_seven.all()
class ProjnSpec(ObservableSpec):
    """Spec for `Projn` objects."""
    # The probability distribution from which the connection weights will be
    # drawn
    dist: rand.Distribution = rand.Scalar(0.5)
    # Selects which pre layer units will be included in the projection.
    # If the length is less than the number of units in the pre_layer, it will
    # be tiled. If the length is more, it will be truncated.
    pre_mask: Iterable[bool] = (True, )
    # Selects which post layer units will be included in the projection.
    # If the length is less than the number of units in the post_layer, it
    # will be tiled. If the length is more, it will be truncated.
    post_mask: Iterable[bool] = (True, )
    # Sparsity of the connection, expressed as a fraction between 0.0 and 1.0
    # (i.e. the proportion of connections that are active).
    sparsity: float = 1.0
    # Set special type of projection. One of ["full", "one_to_one"].
    projn_type = "full"
    # Absolute net input scaling weight
    wt_scale_abs: float = 1.0
    # Relative net input scaling weight (relative to other projections
    # terminating in the same layer)
    wt_scale_rel: float = 1.0
    # Learning rate
    lrate = 0.02
    # Mixing constant determining how much learning is hebbian.
    # See Emergent docs.
    thr_l_mix = 0.1
    # Flag controlling whether thr_l_mix is modulated by cos_diff_avg
    cos_diff_thr_l_mix = False
    # Flag controlling whether the learning rate is modulated by cos_diff_avg
    cos_diff_lrate = False
    # Gain for sigmoidal weight contrast enhancement
    sig_gain = 6
    # Offset for sigmoidal weight contrast enhancement
    sig_offset = 1
    # Minus phase
    minus_phase = events.MinusPhase
    # Plus phase
    plus_phase = events.PlusPhase

    @property
    def _valid_attrs_to_log(self) -> Iterable[str]:
        """Overrides `ObservableSpec._valid_attrs_to_log`."""
        # Valid attributes to log on every cycle
        # When adding any loggable attribute or property to this list,
        # update Projn._whole_attrs or Projn._parts_attrs as appropriate
        # (we represent in two places to avoid a circular dependency)
        return ("conn_wt", "conn_fwt", "cos_diff_avg")

    def validate(self) -> None:  # pylint: disable=W0235
        """Extends `Spec.validate`.

        Runs the base-class validation first, then checks this spec:

        * `dist` must be a `rand.Distribution`.
        * `sparsity` is range-checked against 0.0..1.0.
        * `projn_type` must be "one_to_one" or "full".
        * `wt_scale_abs`, `wt_scale_rel`, `lrate`, `thr_l_mix`,
          `cos_diff_thr_l_mix` and `sig_gain` are range-checked against
          0..Inf.
        * `sig_offset` must pass `assert_sane_float`.
        * `minus_phase` and `plus_phase` must differ.

        Raises:
            ValidationError: for an invalid `dist`, an unknown `projn_type`,
                or identical minus/plus phases. The range and float checks
                report failures through `assert_in_range` /
                `assert_sane_float`.
        """
        super().validate()
        if not isinstance(self.dist, rand.Distribution):
            raise ValidationError("{0} is not a valid "
                                  "distribution.".format(self.dist))
        self.assert_in_range("sparsity", low=0.0, high=1.0)

        valid_projn_types = ["one_to_one", "full"]
        if self.projn_type not in valid_projn_types:
            raise ValidationError(
                "Projn type {0} not one of [\"one_to_one\", \"full\"]".format(
                    self.projn_type))

        self.assert_in_range("wt_scale_abs", 0, float("Inf"))
        self.assert_in_range("wt_scale_rel", 0, float("Inf"))
        self.assert_in_range("lrate", 0, float("Inf"))
        self.assert_in_range("thr_l_mix", 0, float("Inf"))
        # NOTE(review): cos_diff_thr_l_mix defaults to a bool, so this numeric
        # range check likely accepts any True/False value, while
        # cos_diff_lrate is not validated at all — confirm the intent.
        self.assert_in_range("cos_diff_thr_l_mix", 0, float("Inf"))
        self.assert_in_range("sig_gain", 0, float("Inf"))
        self.assert_sane_float("sig_offset")

        # The two learning phases must be distinct event types.
        if self.minus_phase == self.plus_phase:
            raise ValidationError(
                "Minus and plus phase cannot both be {0}".format(
                    self.minus_phase))
def test_scalar_is_always_equal_to_its_value() -> None:
    """Filling a tensor from Scalar(3) sets every element to 3."""
    scalar = rn.Scalar(3)
    # Contents start uninitialized; fill() must overwrite all of them.
    buffer = torch.empty(10)
    scalar.fill(buffer)
    assert torch.all(buffer == 3)