def test_layer_hard_clamping_respects_clamp_max() -> None:
    """Hard-clamped activations are capped at the spec's clamp_max.

    The layer has three units but only two clamp values are given; the
    expected result [0, 0.5, 0] relies on the values being repeated across
    the layer (NOTE(review): confirm hard_clamp tiles short inputs).
    """
    layer = lr.Layer(name="in", size=3, spec=sp.LayerSpec(clamp_max=0.5))
    layer.hard_clamp(act_ext=[0, 1])
    # Also drive the clamp through the event API. The event must name this
    # layer ("in"); it previously named "lr1", which does not match the layer
    # under test, so the event path was not being exercised.
    layer.handle(ev.HardClamp(layer_name="in", acts=[0, 1]))
    expected = [0, 0.5, 0]
    # Guard the length so zip() cannot silently skip missing units.
    assert len(layer.units.act) == len(expected)
    for actual, want in zip(layer.units.act, expected):
        assert math.isclose(actual, want, abs_tol=1e-6)
def test_kwta_suppresses_all_but_k_units() -> None:
    """kWTA inhibition with k=2 leaves exactly two of lr2's units above 0.8.

    lr1 (one unit) is forced to 1 and feeds lr2 (three units) through two
    projections: proj1 with scalar weights of 0.3, and proj2 with scalar
    weights of 0.5 restricted by post_mask=[0, 1, 1]. Unit adaptation and
    spike gain are zeroed in lr2's unit spec. The check uses the activations
    logged on the last of 100 cycles (time 99).
    """
    network = net.Net()
    network.new_layer(name="lr1", size=1)

    kwta_layer_spec = sp.LayerSpec(
        inhibition_type="kwta",
        k=2,
        log_on_cycle=("unit_act", ),
        unit_spec=sp.UnitSpec(adapt_dt=0, spike_gain=0))
    network.new_layer(name="lr2", size=3, spec=kwta_layer_spec)

    weak_spec = sp.ProjnSpec(dist=rn.Scalar(0.3))
    network.new_projn("proj1", "lr1", "lr2", weak_spec)
    masked_spec = sp.ProjnSpec(dist=rn.Scalar(0.5), post_mask=[0, 1, 1])
    network.new_projn("proj2", "lr1", "lr2", masked_spec)

    network.force_layer("lr1", [1])
    for _ in range(100):
        network.cycle()

    cycle_log = network.logs("cycle", "lr2")
    final_acts = cycle_log[cycle_log.time == 99]["act"]
    num_strongly_active = (final_acts > 0.8).sum()
    assert num_strongly_active == 2
def test_net_can_pause_and_resume_logging() -> None:
    """Cycles run while logging is paused leave no rows in the cycle logs.

    Six cycles run: two logged (times 0-1), two paused (times 2-3), and two
    logged again (times 4-5). The per-unit log has one row per unit (two
    units) per logged cycle; the whole-layer log has one row per cycle.
    """
    n = net.Net()
    n.new_layer(
        "layer1", 2,
        spec=specs.LayerSpec(log_on_cycle=("unit_act", "avg_act")))

    def run_cycles(count: int) -> None:
        for _ in range(count):
            n.cycle()

    run_cycles(2)
    n.pause_logging()
    run_cycles(2)
    n.resume_logging()
    run_cycles(2)

    cycle_logs = n.logs("cycle", "layer1")
    parts_time = torch.Tensor(cycle_logs.parts["time"])
    whole_time = torch.Tensor(cycle_logs.whole["time"])

    assert list(parts_time.size()) == [8]
    assert list(whole_time.size()) == [4]
    assert (parts_time == torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])).all()
    assert (whole_time == torch.Tensor([0, 1, 4, 5])).all()
def test_net_trial_log_pausing_and_resuming() -> None:
    """Trials completed while trial logging is paused are absent from logs.

    Three trials run (each a 5-cycle minus phase and a 5-cycle plus phase);
    trial logging is paused only for the middle one, so times 0 and 2 are
    expected in the trial logs.
    """
    n = net.Net()
    n.new_layer(
        "layer1", 2,
        spec=specs.LayerSpec(log_on_trial=("unit_act", "avg_act")))

    def run_trial() -> None:
        n.phase_cycle(phase=events.MinusPhase, num_cycles=5)
        n.phase_cycle(phase=events.PlusPhase, num_cycles=5)
        n.end_trial()

    run_trial()
    n.pause_logging("trial")
    run_trial()
    n.resume_logging("trial")
    run_trial()

    trial_logs = n.logs("trial", "layer1")
    parts_time = torch.Tensor(trial_logs.parts["time"])
    whole_time = torch.Tensor(trial_logs.whole["time"])

    # Two units per logged trial in the per-unit log; one row per trial in
    # the whole-layer log.
    assert list(parts_time.size()) == [4]
    assert list(whole_time.size()) == [2]
    assert (parts_time == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_time == torch.Tensor([0, 2])).all()
def __init__(self, name: str, size: int,
             spec: specs.LayerSpec = None) -> None:
    """Builds a layer backed by a unit group of `size` units.

    Args:
        name: The layer's name (stored as `_name`).
        size: Number of units in the layer.
        spec: Layer parameters. A default `specs.LayerSpec()` is used when
            omitted.
    """
    self._name = name
    self.size = size
    self._spec = specs.LayerSpec() if spec is None else spec
    # `self.spec` (read below) presumably exposes `_spec`, so `_spec` must be
    # assigned before the unit group is built.
    self.units = unit.UnitGroup(size=size, spec=self.spec.unit_spec)

    # Inhibition state: feedback inhibition, then global inhibition.
    self.fbi = 0.0
    self.gc_i = 0.0

    # Clamping state. `hidden` means the layer has never been clamped.
    self.clamped = False
    self.hidden = True

    # Number of units used for k-winners inhibition; never below one.
    self.k = max(1, int(round(self.size * self.spec.kwta_pct)))

    # Zero-initialised per-unit buffers:
    #   act_ext - desired clamping values
    #   acts_p  - activations from the last plus phase
    #   acts_m  - activations from the last minus phase
    self.act_ext = torch.zeros(self.size)
    self.acts_p = torch.zeros(self.size)
    self.acts_m = torch.zeros(self.size)

    # Cosine similarity between acts_p and acts_m, and the same similarity
    # integrated over trials.
    self.cos_diff = 0.0
    self.cos_diff_avg = 0.0

    # The next two are filled every time self.add_input() is called and
    # reset at the end of self.activation_cycle():
    #   input_buffer     - net (excitatory) input collected this cycle; once
    #                      all inputs are in it is normalised by
    #                      wt_scale_rel_sum and sent to the unit group
    #   wt_scale_rel_sum - sum of wt_scale_rel over every projection that
    #                      terminates in this layer
    self.input_buffer = torch.zeros(self.size)
    self.wt_scale_rel_sum = 0.0

    # When adding a loggable attribute or property here, also update
    # specs.LayerSpec._valid_log_on_cycle (the list is duplicated there to
    # avoid a circular dependency).
    whole_attrs: List[str] = ["avg_act", "avg_net", "cos_diff_avg", "fbi"]
    parts_attrs: List[str] = [
        "unit_net", "unit_net_raw", "unit_gc_i", "unit_act", "unit_i_net",
        "unit_i_net_r", "unit_v_m", "unit_v_m_eq", "unit_adapt", "unit_spike"
    ]
    super().__init__(whole_attrs=whole_attrs, parts_attrs=parts_attrs)
def test_attrs_to_log_gets_the_attrs_to_log_by_frequency() -> None:
    """attrs_to_log(freq) returns the tuple configured for that frequency."""
    spec = sp.LayerSpec(
        log_on_cycle=("unit_act", ),
        log_on_trial=("unit_v_m", ),
        log_on_epoch=("unit_spike", ),
        log_on_batch=("unit_i_net", ))
    expectations = [
        (ev.CycleFreq, ("unit_act", )),
        (ev.TrialFreq, ("unit_v_m", )),
        (ev.EpochFreq, ("unit_spike", )),
        (ev.BatchFreq, ("unit_i_net", )),
    ]
    for freq, wanted in expectations:
        assert spec.attrs_to_log(freq) == wanted
def test_you_can_retrieve_the_logs_for_a_layer() -> None:
    """Each frequency's whole-layer log for layer1 has an avg_act column."""
    avg_only = ("avg_act", )
    layer_spec = specs.LayerSpec(
        log_on_cycle=avg_only,
        log_on_trial=avg_only,
        log_on_epoch=avg_only,
        log_on_batch=avg_only)

    n = net.Net()
    n.new_layer(name="layer1", size=3, spec=layer_spec)
    n.plus_phase_cycle(1)
    n.end_epoch()
    n.end_batch()

    frequencies = ("cycle", "trial", "epoch", "batch")
    for frequency in frequencies:
        whole_log = n.logs(frequency, "layer1").whole
        assert "avg_act" in whole_log.columns
def __init__(self, name: str, size: int,
             spec: specs.LayerSpec = None) -> None:
    """Builds a layer holding a list of `size` units.

    Args:
        name: Layer name, forwarded to the base-class initialiser.
        size: Number of units to create.
        spec: Layer parameters; a default `specs.LayerSpec()` is used when
            omitted.
    """
    self.size = size
    self.spec = spec if spec is not None else specs.LayerSpec()
    # One unit per slot, each built from the layer spec's unit spec.
    self.units = [unit.Unit(self.spec.unit_spec) for _ in range(size)]
    # Feedback inhibition starts at zero.
    self.fbi = 0.0
    super().__init__(name)
def test_projn_can_uninhibit_flush() -> None:
    """A projection that was inhibited and then uninhibited delivers input.

    The sender is clamped to 1; the projection handles InhibitProjns and
    flushes, then handles UninhibitProjns and flushes again. Afterwards the
    receiver's input buffer must equal 0.5.
    """
    sender = lr.Layer("lr1", size=1, spec=sp.LayerSpec(clamp_max=1))
    receiver = lr.Layer("lr2", size=1)
    projection = pr.Projn("proj", sender, receiver)

    sender.clamp(act_ext=[1])
    for event_type in (ev.InhibitProjns, ev.UninhibitProjns):
        projection.handle(event_type("proj"))
        projection.flush()

    assert receiver.input_buffer == 0.5
def test_layer_soft_clamping_equivalent_to_input() -> None:
    """Soft clamping a layer matches feeding the same values as net input.

    layer1 receives [0, 1, 0] through add_input() every cycle; layer2 is
    soft-clamped to [0, 1, 0] once. After 50 cycles their unit activations
    must agree to within 1e-6.
    """
    layer1 = lr.Layer(name="test", size=3)
    layer2 = lr.Layer(name="alt", size=3, spec=sp.LayerSpec(clamp_max=1))
    layer2.clamp([0, 1, 0], hard=False)

    for _ in range(50):
        layer1.add_input(torch.Tensor([0, 1, 0]))
        layer1.activation_cycle()
        layer2.activation_cycle()

    # The leftover debug print() calls were removed. Check the size
    # explicitly so zip() cannot silently compare fewer units.
    assert len(layer1.units.act) == len(layer2.units.act) == 3
    for act1, act2 in zip(layer1.units.act, layer2.units.act):
        assert math.isclose(act1, act2, abs_tol=1e-6)
def test_network_can_be_retrieved_and_continue_logging() -> None:
    """A saved and reloaded net keeps its logs and its paused logging state.

    The net logs two cycles, pauses logging, and is saved to disk. The
    reloaded copy must have identical logs, must not log the next two
    cycles (logging is still paused), and must log again after
    resume_logging(). Expected cycle times are 0, 1, 4, 5.
    """
    import os
    import tempfile

    n = net.Net()
    n.new_layer(
        "layer1", 2,
        spec=specs.LayerSpec(log_on_cycle=("unit_act", "avg_act")))
    for _ in range(2):
        n.cycle()
    n.pause_logging()

    # Save into a throwaway directory instead of the hard-coded relative
    # path "tests/mynet.pkl". That path only resolves when pytest runs from
    # the repository root, and it leaves a stale file behind after each run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        location = os.path.join(tmp_dir, "mynet.pkl")
        n.save(location)
        m = net.Net()
        m.load(filename=location)

    before_parts_n = n.logs("cycle", "layer1").parts
    before_parts_m = m.logs("cycle", "layer1").parts
    before_whole_n = n.logs("cycle", "layer1").whole
    before_whole_m = m.logs("cycle", "layer1").whole
    assert np.all((before_parts_n == before_parts_m).values)
    assert np.all((before_whole_n == before_whole_m).values)

    for _ in range(2):
        m.cycle()
    m.resume_logging()
    for _ in range(2):
        m.cycle()

    after_parts_time = torch.Tensor(m.logs("cycle", "layer1").parts["time"])
    after_whole_time = torch.Tensor(m.logs("cycle", "layer1").whole["time"])
    assert list(after_parts_time.size()) == [8]
    assert list(after_whole_time.size()) == [4]
    assert (after_parts_time == torch.Tensor([0, 0, 1, 1, 4, 4, 5, 5])).all()
    assert (after_whole_time == torch.Tensor([0, 1, 4, 5])).all()
def test_net_batch_log_pausing_and_resuming() -> None:
    """Batches ended while batch logging is paused are absent from logs.

    Three batches end; logging is paused only around the middle one, so
    times 0 and 2 are expected (two per-unit rows each, one whole row each).
    """
    n = net.Net()
    n.new_layer(
        "layer1", 2,
        spec=specs.LayerSpec(log_on_batch=("unit_act", "avg_act")))

    n.end_batch()
    n.pause_logging("batch")
    n.end_batch()
    n.resume_logging("batch")
    n.end_batch()

    batch_logs = n.logs("batch", "layer1")
    parts_time = torch.Tensor(batch_logs.parts["time"])
    whole_time = torch.Tensor(batch_logs.whole["time"])

    assert list(parts_time.size()) == [4]
    assert list(whole_time.size()) == [2]
    assert (parts_time == torch.Tensor([0, 0, 2, 2])).all()
    assert (whole_time == torch.Tensor([0, 2])).all()
def test_it_should_check_for_invalid_log_on_batch_attrs() -> None:
    """validate() rejects an unknown attribute name in log_on_batch."""
    unknown_attrs = ("whales", )
    with pytest.raises(sp.ValidationError):
        spec = sp.LayerSpec(log_on_batch=unknown_attrs)
        spec.validate()
def test_every_valid_log_on_cycle_attribute_can_be_logged() -> None:
    """Every attribute LayerSpec lists as loggable passes Layer.validate_attr.

    Iterates LayerSpec's full `_valid_attrs_to_log` collection, despite the
    test's name mentioning only log_on_cycle.
    """
    loggable = sp.LayerSpec()._valid_attrs_to_log
    layer = lr.Layer("lr1", 3)
    for attr_name in loggable:
        layer.validate_attr(attr_name)
def test_it_should_validate_the_unit_spec() -> None:
    """LayerSpec.validate() also validates its nested unit spec."""
    with pytest.raises(sp.ValidationError):
        bad_unit_spec = sp.UnitSpec(vm_dt=-1)
        sp.LayerSpec(unit_spec=bad_unit_spec).validate()
def test_it_should_validate_global_inhibition_multiplier(f) -> None:
    """LayerSpec.validate() rejects the global inhibition multiplier gi=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.LayerSpec(gi=f)
        bad_spec.validate()
def test_it_should_validate_feedback_integration_constant(f) -> None:
    """LayerSpec.validate() rejects the feedback integration constant fb_dt=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.LayerSpec(fb_dt=f)
        bad_spec.validate()
def test_it_should_validate_feedback_inhibition_multiplier(f) -> None:
    """LayerSpec.validate() rejects the feedback inhibition multiplier fb=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.LayerSpec(fb=f)
        bad_spec.validate()
def test_it_should_validate_feedforward_inhibition_multiplier(f) -> None:
    """LayerSpec.validate() rejects the feedforward inhibition multiplier ff=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value. (The `-> None` annotation was added to match the
    sibling tests.)
    """
    with pytest.raises(sp.ValidationError):
        sp.LayerSpec(ff=f).validate()
def test_it_should_check_for_invalid_clamp_max(f) -> None:
    """LayerSpec.validate() rejects clamp_max=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.LayerSpec(clamp_max=f)
        bad_spec.validate()
def test_layer_init_uses_the_spec_you_pass_it() -> None:
    """Layer keeps the exact spec object it was given, not a copy."""
    given_spec = sp.LayerSpec()
    layer = lr.Layer(name="in", size=1, spec=given_spec)
    assert layer.spec is given_spec
def test_layer_should_be_able_to_update_its_units_kwta_avg_inhibition(
) -> None:
    """Smoke test: update_inhibition() runs under kwta_avg inhibition.

    There is no assertion; the test only checks that no exception is raised.
    """
    kwta_avg_spec = sp.LayerSpec(inhibition_type="kwta_avg")
    lr.Layer(name="in", size=3, spec=kwta_avg_spec).update_inhibition()
def test_a_new_layer_validates_its_spec() -> None:
    """Net.new_layer raises ValidationError for an invalid spec (integ=-1)."""
    network = net.Net()
    with pytest.raises(specs.ValidationError):
        invalid_spec = specs.LayerSpec(integ=-1)
        network.new_layer("layer1", 3, spec=invalid_spec)
def test_you_can_retrieve_the_logs_for_a_layer() -> None:
    """After one cycle, the layer's cycle log has an avg_act column.

    NOTE(review): another test in this codebase has this exact name and reads
    `.whole.columns` rather than `.columns`. If both live in the same module,
    pytest collects only the later definition, so consider renaming one and
    confirming which logs API is current.
    """
    n = net.Net()
    n.new_layer("layer1", 3, spec=specs.LayerSpec(log_on_cycle=("avg_act", )))
    n.cycle()
    assert "avg_act" in n.logs("cycle", "layer1").columns
def test_observing_unlogged_attr_raises_error_if_obj_not_observable() -> None:
    """Observing cos_diff_avg via "layer1_cycle_logger" raises ValueError.

    Per the test name, the observed object is presumably not observable;
    layer1 itself only logs avg_act on cycles.
    """
    network = net.Net()
    avg_only_spec = specs.LayerSpec(log_on_cycle=("avg_act", ))
    network.new_layer("layer1", 1, spec=avg_only_spec)
    with pytest.raises(ValueError):
        network.observe("layer1_cycle_logger", "cos_diff_avg")
def test_layer_spec_validates_inhibition_type(f) -> None:
    """Only the known inhibition types pass LayerSpec.validate().

    `f` is injected by a decorator outside this view. Known types must
    validate cleanly; anything else must raise ValidationError.
    """
    known_types = ("kwta", "kwta_avg", "fffb", "none")
    if f in known_types:
        sp.LayerSpec(inhibition_type=f).validate()
        return
    with pytest.raises(sp.ValidationError):
        sp.LayerSpec(inhibition_type=f).validate()
def test_layer_spec_validates_kwta_pt(f) -> None:
    """LayerSpec.validate() rejects an invalid k-winners percentage kwta_pct=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    # The keyword was misspelled `kwta_pt`. The layer initialiser reads
    # `spec.kwta_pct`, so that is the parameter whose validation this test
    # must exercise; the misspelling never touched the real field.
    with pytest.raises(sp.ValidationError):
        sp.LayerSpec(kwta_pct=f).validate()
def test_layer_spec_validates_integ(f) -> None:
    """LayerSpec.validate() rejects integ=f.

    `f` is injected by a decorator outside this view and is presumably an
    out-of-range value.
    """
    with pytest.raises(sp.ValidationError):
        bad_spec = sp.LayerSpec(integ=f)
        bad_spec.validate()
def test_net_multiple_freq_log_pausing_and_resuming() -> None:
    """Pausing several frequencies at once skips logging for each of them.

    The net advances every frequency (cycle, trial, epoch, batch) three
    times, with logging paused for all four during the second pass. Every
    frequency's log must therefore hold times 0 and 2 only: two per-unit
    rows per logged step and one whole-layer row per logged step.
    """
    freqs = ("cycle", "trial", "epoch", "batch")
    loggables = ("unit_act", "avg_act")
    n = net.Net()
    n.new_layer(
        "layer1", 2,
        spec=specs.LayerSpec(
            log_on_cycle=loggables,
            log_on_epoch=loggables,
            log_on_trial=loggables,
            log_on_batch=loggables))

    def advance_all_freqs() -> None:
        # One step at each frequency, in the original order.
        n.cycle()
        n.end_trial()
        n.end_epoch()
        n.end_batch()

    advance_all_freqs()
    n.pause_logging(*freqs)
    advance_all_freqs()
    n.resume_logging(*freqs)
    advance_all_freqs()

    # The four identical per-frequency assertion blocks are collapsed into a
    # loop; the frequency is attached to each assert so failures still say
    # which log is wrong.
    expected_parts = torch.Tensor([0, 0, 2, 2])
    expected_whole = torch.Tensor([0, 2])
    for freq in freqs:
        logs = n.logs(freq, "layer1")
        parts_time = torch.Tensor(logs.parts["time"])
        whole_time = torch.Tensor(logs.whole["time"])
        assert list(parts_time.size()) == [4], freq
        assert list(whole_time.size()) == [2], freq
        assert (parts_time == expected_parts).all(), freq
        assert (whole_time == expected_whole).all(), freq
def test_layer_spec_validates_k(f) -> None:
    """A kwta LayerSpec with k below 1 must fail validation.

    `f` is injected by a decorator outside this view. Values that are not
    below 1 are not checked by this test.
    """
    # Written as `not f < 1` (rather than `f >= 1`) so that NaN follows the
    # same path as before.
    if not f < 1:
        return
    with pytest.raises(sp.ValidationError):
        sp.LayerSpec(inhibition_type="kwta", k=f).validate()